Skip to content

Commit 6765b94

Browse files
zhxchen17facebook-github-bot
authored andcommitted
Enable verifier [2/n] (#1155)
Summary: Pull Request resolved: #1155 X-link: pytorch/pytorch#113075 Turn on verifier check for exportec program ctor. Note that this effectively detect a large surface of spec violations, so we also spend some time fixing them one by one in this diff. Reviewed By: angelayi Differential Revision: D51014944 fbshipit-source-id: 9a9e689ec9c8561a9c5cd4b6bd8836809ae5ff30
1 parent b182621 commit 6765b94

File tree

14 files changed

+233
-149
lines changed

14 files changed

+233
-149
lines changed

backends/example/example_backend.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ def preprocess(
3232
print("entering the lowerable parts in ExampleBackend.preprocess....")
3333

3434
copy_edge_program = copy.deepcopy(edge_program)
35-
copy_edge_program._transform(
36-
PermuteMemoryFormatsPass(),
37-
MergeToDimPass(),
38-
)
39-
processed_bytes = str(copy_edge_program.graph)
35+
graph_module = copy_edge_program.graph_module
36+
graph_module_res = PermuteMemoryFormatsPass()(graph_module)
37+
assert graph_module_res is not None
38+
graph_module_res = MergeToDimPass()(graph_module_res.graph_module)
39+
assert graph_module_res is not None
40+
processed_bytes = str(graph_module_res.graph_module.graph)
4041
return PreprocessResult(bytes(processed_bytes, encoding="utf8"))

exir/backend/backend_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,4 +326,6 @@ def to_backend(
326326
copy.deepcopy(edge_program.range_constraints),
327327
copy.deepcopy(edge_program.equality_constraints),
328328
copy.deepcopy(edge_program.module_call_graph),
329+
None,
330+
edge_program.verifier,
329331
)

exir/capture/_capture.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
flatten_output,
2323
Value,
2424
)
25-
from executorch.exir.verification.verifier import EXIRATenDialectVerifier
25+
from executorch.exir.verification.verifier import EXIRATenDialectVerifierBase
2626
from torch import _guards
2727
from torch._dispatch.python import enable_python_dispatcher
2828
from torch._dynamo.eval_frame import Constraint
@@ -79,6 +79,16 @@ def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
7979
assert output_node.op == "output"
8080
user_outputs = [arg.name for arg in output_node.args[0]]
8181

82+
for n in graph_module.graph.nodes:
83+
if n.op == "call_function" and "val" not in n.meta:
84+
try:
85+
args, kwargs = pytree.tree_map_only(
86+
torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
87+
)
88+
n.meta["val"] = n.target(*args, **kwargs)
89+
except Exception:
90+
n.meta["val"] = None
91+
8292
ep = HackedUpExportedProgramDONOTUSE(
8393
graph_module,
8494
graph_module.graph,
@@ -112,6 +122,7 @@ def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
112122
)
113123
],
114124
None,
125+
EXIRATenDialectVerifierBase,
115126
)
116127
return ExirExportedProgram(ep, False)
117128

@@ -281,6 +292,7 @@ def convert_to_fake(x):
281292
for arg in output_node.args[0]
282293
]
283294

295+
graph_module.graph.eliminate_dead_code()
284296
ep = ExportedProgram(
285297
graph_module,
286298
graph_module.graph,
@@ -300,7 +312,7 @@ def convert_to_fake(x):
300312
)
301313
],
302314
None,
303-
EXIRATenDialectVerifier,
315+
EXIRATenDialectVerifierBase,
304316
)
305317
return ExirExportedProgram(ep, False)
306318

exir/emit/test/test_emit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,12 @@ def test_basic_end_to_end(self) -> None:
146146
self.assertIn(op.overload, {"out", "unary_out"})
147147

148148
self.assertEqual(ops[0].name, "aten::sin")
149-
self.assertEqual(ops[1].name, "aten::max")
150149

151150
self.assertEqual(len(exec_plan.inputs), 1)
152151
self.assertEqual(len(exec_plan.outputs), 1)
153152

154153
self.assertEqual(exec_plan.inputs[0], 0)
155-
self.assertEqual(exec_plan.outputs[0], 2)
154+
self.assertEqual(exec_plan.outputs[0], 1)
156155

157156
def test_nested_return(self) -> None:
158157
def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:

exir/lowered_backend_module.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ def program(self, emit_stacktrace: bool = False) -> Program:
210210
delegate_node.meta["spec"] = tuple(
211211
[make_spec(node.meta["val"]) for node in original_output_nodes]
212212
)
213+
delegate_node.meta["val"] = tuple(
214+
[node.meta["val"] for node in original_output_nodes]
215+
)
213216

214217
# The getitem nodes that are going to be inserted to the lowered graph module
215218
getitem_nodes = []
@@ -218,6 +221,7 @@ def program(self, emit_stacktrace: bool = False) -> Program:
218221
operator.getitem,
219222
args=(delegate_node, i),
220223
)
224+
getitem_node.meta["val"] = delegate_node.meta["val"][i]
221225
getitem_nodes.append(getitem_node)
222226
lowered_exported_program.graph.output(getitem_nodes)
223227

@@ -264,6 +268,8 @@ def program(self, emit_stacktrace: bool = False) -> Program:
264268
range_constraints=lowered_exported_program.range_constraints,
265269
equality_constraints=lowered_exported_program.equality_constraints,
266270
module_call_graph=lowered_exported_program.module_call_graph,
271+
example_inputs=None,
272+
verifier=lowered_exported_program.verifier,
267273
)
268274
exported_program = exported_program._transform(
269275
SpecPropPass(), MemoryPlanningPass("greedy")
@@ -466,6 +472,7 @@ def create_exported_program_from_submodule(
466472
range_constraints=copy.deepcopy(owning_program.range_constraints),
467473
equality_constraints=[],
468474
module_call_graph=[],
475+
verifier=owning_program.verifier,
469476
)
470477

471478

exir/program/_program.py

Lines changed: 98 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from executorch.exir.verification.verifier import (
3333
EXIRATenDialectVerifier,
3434
EXIREdgeDialectVerifier,
35+
get_aten_verifier,
3536
)
3637
from torch._export import ExportedProgram
3738
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
@@ -44,6 +45,21 @@
4445
Val = Any
4546

4647

48+
def _copy_module(new_prog, new_gm):
49+
new_prog.meta.update(new_gm.meta)
50+
new_prog.graph = new_gm.graph
51+
submodules = [name for name, _ in new_prog.named_children()]
52+
for name in submodules:
53+
delattr(new_prog, name)
54+
for name, mod in new_gm.named_children():
55+
setattr(new_prog, name, mod)
56+
for node in new_gm.graph.nodes:
57+
if node.op == "get_attr":
58+
t = getattr(new_gm, node.target, None)
59+
if isinstance(t, torch.Tensor):
60+
setattr(new_prog, node.target, t)
61+
62+
4763
def lift_constant_tensor_pass(ep):
4864
"""
4965
Takes an ExportedProgram and returns the ExportedProgram modified in-place,
@@ -131,6 +147,7 @@ def __init__(
131147
equality_constraints,
132148
module_call_graph,
133149
example_inputs,
150+
verifier,
134151
):
135152
super().__init__(
136153
root,
@@ -141,8 +158,8 @@ def __init__(
141158
equality_constraints,
142159
module_call_graph,
143160
example_inputs,
161+
verifier,
144162
)
145-
self._dialect = "HACKED_ATEN"
146163

147164
def __call__(self, *args: Any, **kwargs: Any) -> Any:
148165
import torch._export.error as error
@@ -245,9 +262,15 @@ def to_executorch(
245262
if not self.after_to_edge_passes:
246263
raise RuntimeError("Must run to_edge before to_executorch.")
247264
config = config or ExecutorchBackendConfig()
248-
ep = self.exported_program
249-
new_prog = ep._transform(*edge_to_executorch_passes(config))
250-
new_prog = ExirExportedProgram(new_prog, self.after_to_edge_passes)
265+
new_gm = self.exported_program.graph_module
266+
for p in edge_to_executorch_passes(config):
267+
new_gm_res = p(new_gm)
268+
assert new_gm_res is not None
269+
new_gm = new_gm_res.graph_module
270+
new_prog = ExirExportedProgram(
271+
copy.deepcopy(self.exported_program), self.after_to_edge_passes
272+
)
273+
_copy_module(new_prog.exported_program.graph_module, new_gm)
251274
executorch_prog = ExecutorchProgram(
252275
new_prog,
253276
emit_stacktrace=config.emit_stacktrace,
@@ -256,9 +279,7 @@ def to_executorch(
256279
constant_tensor_alignment=config.constant_tensor_alignment,
257280
delegate_alignment=config.delegate_alignment,
258281
)
259-
executorch_prog.graph_module.meta.update(
260-
new_prog.exported_program.graph_module.meta
261-
)
282+
executorch_prog.graph_module.meta.update(new_gm.meta)
262283
executorch_prog.graph_module.meta.update(
263284
self.exported_program.graph_module.meta
264285
)
@@ -356,6 +377,22 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
356377
)
357378
raise
358379

380+
dialect = ep.exported_program.dialect
381+
if dialect == "ATEN":
382+
ep = ExirExportedProgram(
383+
ExportedProgram(
384+
ep.exported_program.graph_module,
385+
ep.exported_program.graph_module.graph,
386+
ep.exported_program.graph_signature,
387+
ep.exported_program.state_dict,
388+
ep.exported_program.range_constraints,
389+
ep.exported_program.equality_constraints,
390+
ep.exported_program.module_call_graph,
391+
ep.exported_program.example_inputs,
392+
verifier=get_aten_verifier(enable=config._check_ir_validity),
393+
),
394+
False,
395+
)
359396
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
360397
# use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
361398
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
@@ -364,16 +401,23 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
364401
post_op_replace_passes = aten_to_edge_passes.passes[-2:]
365402

366403
new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
367-
if new_ep.exported_program.dialect == "ATEN":
404+
if dialect == "ATEN":
368405
new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)
369406

407+
new_gm = new_ep.exported_program.graph_module
370408
if config._use_edge_ops:
371-
new_ep = new_ep.transform(OpReplacePass())
409+
new_gm_res = OpReplacePass()(new_gm)
410+
assert new_gm_res is not None
411+
new_gm = new_gm_res.graph_module
412+
413+
for p in post_op_replace_passes:
414+
new_gm_res = p(new_gm)
415+
assert new_gm_res is not None
416+
new_gm = new_gm_res.graph_module
372417

373-
new_ep = new_ep.transform(*post_op_replace_passes)
374418
new_ep.exported_program = ExportedProgram(
375-
new_ep.exported_program.graph_module,
376-
new_ep.exported_program.graph,
419+
new_gm,
420+
new_gm.graph,
377421
new_ep.exported_program.graph_signature,
378422
new_ep.exported_program.state_dict,
379423
new_ep.exported_program.range_constraints,
@@ -386,11 +430,6 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
386430
class_only=True,
387431
),
388432
)
389-
if config._check_ir_validity:
390-
# TODO(zhxchen17) Remove this call after we turn on verifier in ctor.
391-
EXIREdgeDialectVerifier(check_edge_ops=config._use_edge_ops)(
392-
new_ep.exported_program.graph_module
393-
)
394433
new_ep.after_to_edge_passes = True
395434
return new_ep
396435

@@ -623,8 +662,18 @@ def multi_method_program_to_executorch(
623662
) -> MultiMethodExecutorchProgram:
624663
config = config or ExecutorchBackendConfig()
625664
passes = edge_to_executorch_passes(config)
665+
res = {}
666+
for method_name, prog in edge_dialect_program._method_to_program.items():
667+
new_prog = copy.deepcopy(prog)
668+
gm = prog.exported_program.graph_module
669+
for p in passes:
670+
gm_res = p(gm)
671+
assert gm_res is not None
672+
gm = gm_res.graph_module
673+
_copy_module(new_prog.exported_program.graph_module, gm)
674+
res[method_name] = new_prog
626675
return MultiMethodExecutorchProgram(
627-
executorch_dialect_program=edge_dialect_program.transform(*passes),
676+
executorch_dialect_program=MultiMethodExirExportedProgram(res),
628677
emit_stacktrace=config.emit_stacktrace,
629678
extract_segments=config.extract_segments,
630679
segment_alignment=config.segment_alignment,
@@ -674,8 +723,6 @@ def to_edge(
674723
logging.info(f"Input program {name} is not in ATen dialect.")
675724
raise e
676725

677-
op_replace_pass = [OpReplacePass()] if config._use_edge_ops else []
678-
679726
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
680727
# use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
681728
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
@@ -686,9 +733,31 @@ def to_edge(
686733
ReplaceViewOpsWithViewCopyOpsPass()
687734
) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
688735
passes.extend(aten_to_edge_passes.passes[:-2])
689-
passes.extend(op_replace_pass)
690-
passes.extend(aten_to_edge_passes.passes[-2:])
691736
edge_program = program._transform(*passes)
737+
if config._use_edge_ops:
738+
gm_res = OpReplacePass()(edge_program.graph_module)
739+
assert gm_res is not None
740+
gm = gm_res.graph_module
741+
else:
742+
gm = edge_program.graph_module
743+
edge_program = ExportedProgram(
744+
root=gm,
745+
graph=gm.graph,
746+
graph_signature=edge_program.graph_signature,
747+
state_dict=edge_program.state_dict,
748+
range_constraints=edge_program.range_constraints,
749+
equality_constraints=edge_program.equality_constraints,
750+
module_call_graph=edge_program.module_call_graph,
751+
example_inputs=edge_program.example_inputs,
752+
verifier=EXIREdgeDialectVerifier(
753+
check_edge_ops=config._use_edge_ops,
754+
enable=config._check_ir_validity,
755+
class_only=True,
756+
),
757+
)
758+
passes = []
759+
passes.extend(aten_to_edge_passes.passes[-2:])
760+
edge_program = edge_program._transform(*passes)
692761
try:
693762
EXIREdgeDialectVerifier(
694763
check_edge_ops=config._use_edge_ops, enable=config._check_ir_validity
@@ -852,7 +921,13 @@ def to_executorch(
852921

853922
execution_programs: Dict[str, ExportedProgram] = {}
854923
for name, program in self._edge_programs.items():
855-
new_prog = program._transform(*edge_to_executorch_passes(config))
924+
new_gm = program.graph_module
925+
for p in edge_to_executorch_passes(config):
926+
new_gm_res = p(new_gm)
927+
assert new_gm_res is not None
928+
new_gm = new_gm_res.graph_module
929+
new_prog = copy.deepcopy(program)
930+
_copy_module(new_prog.graph_module, new_gm)
856931
execution_programs[name] = new_prog
857932

858933
return ExecutorchProgramManager(

exir/serde/serialize.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from torch._export.serde.schema import GraphSignature
3939
from torch._export.serde.serialize import SerializeError
40+
from torch._export.verifier import load_verifier
4041
from torch.export.exported_program import (
4142
ExportGraphSignature,
4243
ModuleCallEntry,
@@ -367,7 +368,7 @@ def serialize(
367368
range_constraints=serialized_range_constraints,
368369
equality_constraints=serialized_equality_constraints,
369370
schema_version=schema.SCHEMA_VERSION,
370-
example_inputs=None,
371+
dialect=exported_program.dialect,
371372
),
372373
export_serialize.serialize_torch_artifact(gm_serializer.state_dict),
373374
)
@@ -686,15 +687,23 @@ def deserialize(
686687
serialized_exported_program.equality_constraints
687688
)
688689

689-
return exir.ExportedProgram(
690+
dummy_g = torch.fx.Graph()
691+
dummy_g.output(())
692+
ep = exir.ExportedProgram(
690693
state_dict,
691-
graph_module.graph,
694+
dummy_g,
692695
sig,
693696
{}, # TODO(T157676982)
694697
range_constraints,
695698
equality_constraints,
696699
module_call_graph,
700+
None,
701+
load_verifier(serialized_exported_program.dialect),
697702
)
703+
ep.graph_module.graph = graph_module.graph
704+
for name, t in state_dict.items():
705+
setattr(ep.graph_module, name, t)
706+
return ep
698707

699708

700709
def serialize(

exir/tests/control_flow_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def __init__(self):
6565

6666
def forward(self, inp):
6767
def true_branch(x):
68-
return torch.inverse(x).contiguous()
68+
x - 1
69+
return x + 1
6970

7071
def false_branch(x):
7172
return x * 2

0 commit comments

Comments
 (0)