|
22 | 22 | from typing import List |
23 | 23 |
|
24 | 24 | import torch |
25 | | -from torch.fx import GraphModule, Node |
| 25 | +from torch.fx import GraphModule, Node, Proxy |
26 | 26 | from transformers.file_utils import add_end_docstrings |
27 | 27 |
|
28 | 28 |
|
29 | | -try: |
30 | | - from transformers.utils.fx import _gen_constructor_wrapper |
31 | | -except ImportError: |
32 | | - from transformers.utils.fx import gen_constructor_wrapper |
33 | | - |
34 | | - def _gen_constructor_wrapper(*args, **kwargs): |
35 | | - wrapper, target = gen_constructor_wrapper(*args, **kwargs) |
36 | | - |
37 | | - def wrapper_with_forced_tracing(*_args, **_kwargs): |
38 | | - import torch.fx._symbolic_trace |
39 | | - |
40 | | - original_flag = torch.fx._symbolic_trace._is_fx_tracing_flag |
41 | | - torch.fx._symbolic_trace._is_fx_tracing_flag = True |
42 | | - out = wrapper(*_args, **_kwargs) |
43 | | - torch.fx._symbolic_trace._is_fx_tracing_flag = original_flag |
44 | | - return out |
45 | | - |
46 | | - return wrapper_with_forced_tracing, target |
47 | | - |
48 | | - |
49 | 29 | _ATTRIBUTES_DOCSTRING = r""" |
50 | 30 | Attributes: |
51 | 31 | preserves_computation (`bool`, defaults to `False`): |
@@ -819,3 +799,24 @@ def reverse(self, graph_module): |
819 | 799 | return ComposeTransformation._reverse_composition(graph_module) |
820 | 800 |
|
821 | 801 | return ComposeTransformation() |
| 802 | + |
| 803 | + |
| 804 | +def _gen_constructor_wrapper(target): |
| 805 | + @functools.wraps(target) |
| 806 | + def wrapper(*args, **kwargs): |
| 807 | + proxy = None |
| 808 | + |
| 809 | + def check_has_proxy(v): |
| 810 | + if isinstance(v, Proxy): |
| 811 | + nonlocal proxy |
| 812 | + proxy = v |
| 813 | + |
| 814 | + torch.fx.node.map_aggregate(args, check_has_proxy) |
| 815 | + torch.fx.node.map_aggregate(kwargs, check_has_proxy) |
| 816 | + |
| 817 | + if proxy is not None: |
| 818 | + return proxy.tracer.create_proxy("call_function", target, args, kwargs) |
| 819 | + else: |
| 820 | + return target(*args, **kwargs) |
| 821 | + |
| 822 | + return wrapper, target |
0 commit comments