Skip to content

Commit 2639cf0

Browse files
author
James Reed
authored
Merge pull request #919 from jamesr66a/fix_wrap
[FX] Fix wrap output example
2 parents cbb760d + 6126062 commit 2639cf0

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

fx/wrap_output_dynamically.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,19 @@ def wrap_in_activation_function(m: GraphModule, fn: ActivationFunction) -> Graph
6767
# Call the specified activation function using the Proxy wrapper for
6868
# `output_op`. The result of this call is another Proxy, which we
6969
# can hook into our existing Graph.
70-
with traced.graph.inserting_before(wrap_node):
70+
with traced.graph.inserting_after(wrap_node):
7171
fn_impl_output_node = fn_impl_traced(wrap_proxy)
7272
new_args = (fn_impl_output_node.node,)
7373
output_node.args = new_args
7474

75+
m.recompile()
76+
7577

7678
# Example call
79+
x, y = torch.randn(5, 3), torch.randn(5, 3)
80+
orig_output = traced(x, y)
81+
7782
wrap_in_activation_function(traced, ActivationFunction.LEAKY_RELU)
83+
new_output = traced(x, y)
84+
85+
torch.testing.assert_allclose(new_output, torch.nn.LeakyReLU()(orig_output))

0 commit comments

Comments
 (0)