32
32
from executorch .exir .verification .verifier import (
33
33
EXIRATenDialectVerifier ,
34
34
EXIREdgeDialectVerifier ,
35
+ get_aten_verifier ,
35
36
)
36
37
from torch ._export import ExportedProgram
37
38
from torch ._export .passes import ReplaceViewOpsWithViewCopyOpsPass
44
45
Val = Any
45
46
46
47
48
+ def _copy_module (new_prog , new_gm ):
49
+ new_prog .meta .update (new_gm .meta )
50
+ new_prog .graph = new_gm .graph
51
+ submodules = [name for name , _ in new_prog .named_children ()]
52
+ for name in submodules :
53
+ delattr (new_prog , name )
54
+ for name , mod in new_gm .named_children ():
55
+ setattr (new_prog , name , mod )
56
+ for node in new_gm .graph .nodes :
57
+ if node .op == "get_attr" :
58
+ t = getattr (new_gm , node .target , None )
59
+ if isinstance (t , torch .Tensor ):
60
+ setattr (new_prog , node .target , t )
61
+
62
+
47
63
def lift_constant_tensor_pass (ep ):
48
64
"""
49
65
Takes an ExportedProgram and returns the ExportedProgram modified in-place,
@@ -131,6 +147,7 @@ def __init__(
131
147
equality_constraints ,
132
148
module_call_graph ,
133
149
example_inputs ,
150
+ verifier ,
134
151
):
135
152
super ().__init__ (
136
153
root ,
@@ -141,8 +158,8 @@ def __init__(
141
158
equality_constraints ,
142
159
module_call_graph ,
143
160
example_inputs ,
161
+ verifier ,
144
162
)
145
- self ._dialect = "HACKED_ATEN"
146
163
147
164
def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
148
165
import torch ._export .error as error
@@ -245,9 +262,15 @@ def to_executorch(
245
262
if not self .after_to_edge_passes :
246
263
raise RuntimeError ("Must run to_edge before to_executorch." )
247
264
config = config or ExecutorchBackendConfig ()
248
- ep = self .exported_program
249
- new_prog = ep ._transform (* edge_to_executorch_passes (config ))
250
- new_prog = ExirExportedProgram (new_prog , self .after_to_edge_passes )
265
+ new_gm = self .exported_program .graph_module
266
+ for p in edge_to_executorch_passes (config ):
267
+ new_gm_res = p (new_gm )
268
+ assert new_gm_res is not None
269
+ new_gm = new_gm_res .graph_module
270
+ new_prog = ExirExportedProgram (
271
+ copy .deepcopy (self .exported_program ), self .after_to_edge_passes
272
+ )
273
+ _copy_module (new_prog .exported_program .graph_module , new_gm )
251
274
executorch_prog = ExecutorchProgram (
252
275
new_prog ,
253
276
emit_stacktrace = config .emit_stacktrace ,
@@ -256,9 +279,7 @@ def to_executorch(
256
279
constant_tensor_alignment = config .constant_tensor_alignment ,
257
280
delegate_alignment = config .delegate_alignment ,
258
281
)
259
- executorch_prog .graph_module .meta .update (
260
- new_prog .exported_program .graph_module .meta
261
- )
282
+ executorch_prog .graph_module .meta .update (new_gm .meta )
262
283
executorch_prog .graph_module .meta .update (
263
284
self .exported_program .graph_module .meta
264
285
)
@@ -356,6 +377,22 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
356
377
)
357
378
raise
358
379
380
+ dialect = ep .exported_program .dialect
381
+ if dialect == "ATEN" :
382
+ ep = ExirExportedProgram (
383
+ ExportedProgram (
384
+ ep .exported_program .graph_module ,
385
+ ep .exported_program .graph_module .graph ,
386
+ ep .exported_program .graph_signature ,
387
+ ep .exported_program .state_dict ,
388
+ ep .exported_program .range_constraints ,
389
+ ep .exported_program .equality_constraints ,
390
+ ep .exported_program .module_call_graph ,
391
+ ep .exported_program .example_inputs ,
392
+ verifier = get_aten_verifier (enable = config ._check_ir_validity ),
393
+ ),
394
+ False ,
395
+ )
359
396
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
360
397
# use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
361
398
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
@@ -364,16 +401,23 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
364
401
post_op_replace_passes = aten_to_edge_passes .passes [- 2 :]
365
402
366
403
new_ep = copy .deepcopy (ep ).transform (* pre_op_replace_passes )
367
- if new_ep . exported_program . dialect == "ATEN" :
404
+ if dialect == "ATEN" :
368
405
new_ep .exported_program = lift_constant_tensor_pass (new_ep .exported_program )
369
406
407
+ new_gm = new_ep .exported_program .graph_module
370
408
if config ._use_edge_ops :
371
- new_ep = new_ep .transform (OpReplacePass ())
409
+ new_gm_res = OpReplacePass ()(new_gm )
410
+ assert new_gm_res is not None
411
+ new_gm = new_gm_res .graph_module
412
+
413
+ for p in post_op_replace_passes :
414
+ new_gm_res = p (new_gm )
415
+ assert new_gm_res is not None
416
+ new_gm = new_gm_res .graph_module
372
417
373
- new_ep = new_ep .transform (* post_op_replace_passes )
374
418
new_ep .exported_program = ExportedProgram (
375
- new_ep . exported_program . graph_module ,
376
- new_ep . exported_program .graph ,
419
+ new_gm ,
420
+ new_gm .graph ,
377
421
new_ep .exported_program .graph_signature ,
378
422
new_ep .exported_program .state_dict ,
379
423
new_ep .exported_program .range_constraints ,
@@ -386,11 +430,6 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
386
430
class_only = True ,
387
431
),
388
432
)
389
- if config ._check_ir_validity :
390
- # TODO(zhxchen17) Remove this call after we turn on verifier in ctor.
391
- EXIREdgeDialectVerifier (check_edge_ops = config ._use_edge_ops )(
392
- new_ep .exported_program .graph_module
393
- )
394
433
new_ep .after_to_edge_passes = True
395
434
return new_ep
396
435
@@ -623,8 +662,18 @@ def multi_method_program_to_executorch(
623
662
) -> MultiMethodExecutorchProgram :
624
663
config = config or ExecutorchBackendConfig ()
625
664
passes = edge_to_executorch_passes (config )
665
+ res = {}
666
+ for method_name , prog in edge_dialect_program ._method_to_program .items ():
667
+ new_prog = copy .deepcopy (prog )
668
+ gm = prog .exported_program .graph_module
669
+ for p in passes :
670
+ gm_res = p (gm )
671
+ assert gm_res is not None
672
+ gm = gm_res .graph_module
673
+ _copy_module (new_prog .exported_program .graph_module , gm )
674
+ res [method_name ] = new_prog
626
675
return MultiMethodExecutorchProgram (
627
- executorch_dialect_program = edge_dialect_program . transform ( * passes ),
676
+ executorch_dialect_program = MultiMethodExirExportedProgram ( res ),
628
677
emit_stacktrace = config .emit_stacktrace ,
629
678
extract_segments = config .extract_segments ,
630
679
segment_alignment = config .segment_alignment ,
@@ -674,8 +723,6 @@ def to_edge(
674
723
logging .info (f"Input program { name } is not in ATen dialect." )
675
724
raise e
676
725
677
- op_replace_pass = [OpReplacePass ()] if config ._use_edge_ops else []
678
-
679
726
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
680
727
# use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
681
728
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
@@ -686,9 +733,31 @@ def to_edge(
686
733
ReplaceViewOpsWithViewCopyOpsPass ()
687
734
) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
688
735
passes .extend (aten_to_edge_passes .passes [:- 2 ])
689
- passes .extend (op_replace_pass )
690
- passes .extend (aten_to_edge_passes .passes [- 2 :])
691
736
edge_program = program ._transform (* passes )
737
+ if config ._use_edge_ops :
738
+ gm_res = OpReplacePass ()(edge_program .graph_module )
739
+ assert gm_res is not None
740
+ gm = gm_res .graph_module
741
+ else :
742
+ gm = edge_program .graph_module
743
+ edge_program = ExportedProgram (
744
+ root = gm ,
745
+ graph = gm .graph ,
746
+ graph_signature = edge_program .graph_signature ,
747
+ state_dict = edge_program .state_dict ,
748
+ range_constraints = edge_program .range_constraints ,
749
+ equality_constraints = edge_program .equality_constraints ,
750
+ module_call_graph = edge_program .module_call_graph ,
751
+ example_inputs = edge_program .example_inputs ,
752
+ verifier = EXIREdgeDialectVerifier (
753
+ check_edge_ops = config ._use_edge_ops ,
754
+ enable = config ._check_ir_validity ,
755
+ class_only = True ,
756
+ ),
757
+ )
758
+ passes = []
759
+ passes .extend (aten_to_edge_passes .passes [- 2 :])
760
+ edge_program = edge_program ._transform (* passes )
692
761
try :
693
762
EXIREdgeDialectVerifier (
694
763
check_edge_ops = config ._use_edge_ops , enable = config ._check_ir_validity
@@ -852,7 +921,13 @@ def to_executorch(
852
921
853
922
execution_programs : Dict [str , ExportedProgram ] = {}
854
923
for name , program in self ._edge_programs .items ():
855
- new_prog = program ._transform (* edge_to_executorch_passes (config ))
924
+ new_gm = program .graph_module
925
+ for p in edge_to_executorch_passes (config ):
926
+ new_gm_res = p (new_gm )
927
+ assert new_gm_res is not None
928
+ new_gm = new_gm_res .graph_module
929
+ new_prog = copy .deepcopy (program )
930
+ _copy_module (new_prog .graph_module , new_gm )
856
931
execution_programs [name ] = new_prog
857
932
858
933
return ExecutorchProgramManager (
0 commit comments