Skip to content

Commit 0012ffa

Browse files
authored
Add preserve_ops to EdgeCompileConfig (#12546)
1. Add `preserve_ops` to `EdgeCompileConfig` 2. Remove preserved ops from the decomposition table in `to_edge`. 3. Add checks to the verifier ensuring that preserved ops do not have mutations or views. 4. Update 'core_aten_exception_list' to be 'preserved_ops' in `to_edge_transform_and_lower`. Context/Usage **core_aten_ops_exception_list** - Contains operators that are missing a decomposition to core aten. - Exclude these so that verification can still be run on the rest of the graph. - Ideally, this list should be empty. **preserve_ops** - Contains operators that the user specifically does not want decomposed. - Must be aten; custom ops are ignored by verifier. Edge case: - If an aten operator does not have a decomp, and the user specifically wants it to be preserved, put it in preserve_ops rather than core_aten_ops_exception_list. Differential Revision: [D78298749](https://our.internmc.facebook.com/intern/diff/D78298749/)
1 parent 80da097 commit 0012ffa

File tree

6 files changed

+121
-38
lines changed

6 files changed

+121
-38
lines changed

backends/nxp/nxp_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def preprocess(
174174
# Otherwise, we get violation that this op is not part of ATen Core ops.
175175
edge_program._verifiers = [
176176
EXIREdgeDialectVerifier(
177-
class_only=True, exception_list=[torch.ops.aten.max_pool2d.default]
177+
class_only=True, core_aten_ops_exception_list=[torch.ops.aten.max_pool2d.default]
178178
)
179179
]
180180

exir/capture/_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,14 @@ class EdgeCompileConfig:
4040
# TODO(larryliu): remove this
4141
_use_edge_ops: bool = True
4242
# Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks.
43+
# Note: only use this for core ATen ops that are missing decompositions. This is temporary,
44+
# enabling verification on the rest of the program until decomposition coverage is improved.
4345
_core_aten_ops_exception_list: List[torch._ops.OpOverload] = field(
4446
default_factory=list
4547
)
48+
# Allow ops to be preserved in the graph, i.e., prevent them from being decomposed.
49+
# These may be core or non-core ATen ops; custom ops should not be here.
50+
_preserve_ops: List[torch.torch._ops.OpOverload] = field(default_factory=list)
4651
# TODO(gasoonjia): remove this
4752
_skip_dim_order: bool = False
4853

exir/program/_program.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -795,9 +795,19 @@ def _generate_edge_program(
795795
name: str,
796796
config: EdgeCompileConfig,
797797
program: ExportedProgram,
798-
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
798+
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
799+
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
799800
) -> ExportedProgram:
800-
801+
"""
802+
Args:
803+
name: The name of the program.
804+
config: The configuration for the edge program.
805+
program: The exported program to be converted to an edge program.
806+
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
807+
preserve_ops: A list of aten ops that should not be decomposed.
808+
Returns:
809+
An ExportedProgram in edge dialect.
810+
"""
801811
# Remove invalid assert ops, such as _assert_tensor_metadata
802812
gm = program.graph_module
803813
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
@@ -812,7 +822,8 @@ def _generate_edge_program(
812822
EXIRATenDialectVerifier(
813823
edge_compile_config=config,
814824
class_only=False,
815-
exception_list=ops_set_to_not_decompose,
825+
core_aten_ops_exception_list=core_aten_ops_exception_list,
826+
preserve_ops=preserve_ops,
816827
)(gm)
817828
except ExportError as e:
818829
logging.info(f"Input program {name} is not in ATen dialect.")
@@ -848,7 +859,8 @@ def _generate_edge_program(
848859
EXIREdgeDialectVerifier(
849860
edge_compile_config=config,
850861
class_only=True,
851-
exception_list=ops_set_to_not_decompose,
862+
core_aten_ops_exception_list=core_aten_ops_exception_list,
863+
preserve_ops=preserve_ops,
852864
)
853865
],
854866
)
@@ -864,7 +876,7 @@ def _replace_aten_ops_with_transformed_ops(
864876
program: ExportedProgram,
865877
partitioner,
866878
):
867-
ops_to_not_decompose = set()
879+
preserve_ops = set()
868880
partitioners = partitioner.get(name)
869881
if partitioners is None:
870882
return
@@ -889,7 +901,7 @@ def _replace_aten_ops_with_transformed_ops(
889901
and node.target in ops_set_to_not_decompose
890902
and is_op_supported
891903
):
892-
ops_to_not_decompose.add(node.target)
904+
preserve_ops.add(node.target)
893905
node.target = aten_op_to_transform_op[node.target]
894906

895907
for _, submod, _ in get_control_flow_submodules(program.graph_module):
@@ -900,10 +912,10 @@ def _replace_aten_ops_with_transformed_ops(
900912
and node.target in ops_set_to_not_decompose
901913
and is_op_supported
902914
):
903-
ops_to_not_decompose.add(node.target)
915+
preserve_ops.add(node.target)
904916
node.target = aten_op_to_transform_op[node.target]
905917

906-
return ops_to_not_decompose
918+
return preserve_ops
907919

908920

909921
def _restore_transformed_ops_to_aten_ops(program: ExportedProgram):
@@ -1014,7 +1026,7 @@ def _sanity_check_graph_for_non_decomp_ops(
10141026

10151027

10161028
def _remove_invalid_ops_for_not_decompose(
1017-
ops_to_not_decompose: List[torch._ops.OpOverload],
1029+
preserve_ops: List[torch._ops.OpOverload],
10181030
) -> List[torch._ops.OpOverload]:
10191031
_logged_warnings = set()
10201032

@@ -1079,7 +1091,7 @@ def keep(op):
10791091
return False
10801092
return True
10811093

1082-
return list(filter(keep, ops_to_not_decompose))
1094+
return list(filter(keep, preserve_ops))
10831095

10841096

10851097
def _gen_edge_manager_for_partitioners(
@@ -1136,7 +1148,7 @@ def _gen_edge_manager_for_partitioners(
11361148
name,
11371149
config,
11381150
program,
1139-
list(ops_set_to_not_decompose_by_program.get(name, [])),
1151+
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
11401152
)
11411153

11421154
edge_manager = EdgeProgramManager(
@@ -1281,7 +1293,7 @@ def to_edge_transform_and_lower(
12811293
EXIREdgeDialectVerifier(
12821294
edge_compile_config=config,
12831295
class_only=True,
1284-
exception_list=list(ops_set_to_not_decompose),
1296+
preserve_ops=list(ops_set_to_not_decompose),
12851297
)()(program.graph_module)
12861298

12871299
return edge_manager
@@ -1328,7 +1340,7 @@ def to_edge_with_preserved_ops(
13281340
table.pop(op, None)
13291341
program = program.run_decompositions(table)
13301342
edge_programs[name] = _generate_edge_program(
1331-
name, config, program, list(preserve_ops)
1343+
name, config, program, preserve_ops=list(preserve_ops)
13321344
)
13331345

13341346
return EdgeProgramManager(
@@ -1367,8 +1379,16 @@ def to_edge(
13671379

13681380
for name, program in aten_programs.items():
13691381
# Decompose to Core ATen
1370-
program = program.run_decompositions(_default_decomposition_table())
1371-
edge_programs[name] = _generate_edge_program(name, config, program)
1382+
table = _default_decomposition_table()
1383+
preserve_ops = []
1384+
if compile_config:
1385+
preserve_ops = compile_config._preserve_ops
1386+
for op in compile_config._preserve_ops:
1387+
table.pop(op, None)
1388+
program = program.run_decompositions(table)
1389+
edge_programs[name] = _generate_edge_program(
1390+
name, config, program, preserve_ops=preserve_ops
1391+
)
13721392

13731393
return EdgeProgramManager(edge_programs, constant_methods, config)
13741394

@@ -1389,7 +1409,8 @@ def __init__(
13891409
edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
13901410
constant_methods: Optional[Dict[str, Any]] = None,
13911411
compile_config: Optional[EdgeCompileConfig] = None,
1392-
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
1412+
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
1413+
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
13931414
):
13941415
"""
13951416
Should not be called directly by users. User should use :func:'to_edge' instead.
@@ -1404,7 +1425,8 @@ def __init__(
14041425
try:
14051426
EXIREdgeDialectVerifier(
14061427
edge_compile_config=self.compile_config,
1407-
exception_list=ops_set_to_not_decompose,
1428+
core_aten_ops_exception_list=core_aten_ops_exception_list,
1429+
preserve_ops=preserve_ops,
14081430
)(program.graph_module)
14091431
except ExportError as e:
14101432
logging.info(f"Input program {name} is not in aten dialect.")

exir/program/test/test_program.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
ExecutorchProgramManager,
2828
to_edge,
2929
to_edge_transform_and_lower,
30-
to_edge_with_preserved_ops,
3130
)
3231
from executorch.exir.tracer import _default_decomposition_table
3332
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
@@ -784,7 +783,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
784783
def _test_to_edge_with_preserved_ops(
785784
self, program, preserved_ops, expected_preserved_ops
786785
):
787-
edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)
786+
edge = to_edge(
787+
program, compile_config=EdgeCompileConfig(_preserve_ops=preserved_ops)
788+
)
788789

789790
def count_nodes(graph_module, target):
790791
count = 0

exir/verification/test/test_verifier.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,17 @@ def forward(self, input, label):
161161
edge_verifier = EXIREdgeDialectVerifier()
162162

163163
edge_verifier(edge.exported_program())
164+
165+
def test_verifier_preserve_ops_view(self) -> None:
166+
class TestExpand(nn.Module):
167+
def __init__(self):
168+
super().__init__()
169+
170+
def forward(self, x):
171+
return x.expand(2, 2, 2, 2)
172+
173+
model = TestExpand()
174+
config = EdgeCompileConfig(_preserve_ops=[torch.ops.aten.expand.default])
175+
export_model = export(model, (torch.randn(2, 2, 2, 2),), strict=True)
176+
with self.assertRaises(RuntimeError):
177+
to_edge(export_model, compile_config=config)

exir/verification/verifier.py

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
#
7+
# pyre-unsafe
68

79
import itertools
10+
import logging
811
import operator
912
import types
1013
from contextlib import nullcontext
@@ -81,26 +84,33 @@ def __call__(self, *args, **kwargs):
8184
def EXIRATenDialectVerifier( # noqa: C901
8285
edge_compile_config: Optional[EdgeCompileConfig] = None,
8386
class_only: bool = False,
84-
exception_list: Optional[List[torch._ops.OpOverload]] = None,
87+
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
88+
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
8589
):
8690
"""
8791
Returns a verifier class that runs ATen dialect specific checks on the graph module.
8892
"""
93+
_core_aten_ops_exception_list = core_aten_ops_exception_list or []
94+
_preserve_ops = preserve_ops or []
8995
# merge the exception list from edge_compile_config and exception_list
90-
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
91-
exception_list = edge_compile_config._core_aten_ops_exception_list + (
92-
exception_list or []
93-
)
96+
if edge_compile_config:
97+
if edge_compile_config._core_aten_ops_exception_list:
98+
_core_aten_ops_exception_list.extend(
99+
edge_compile_config._core_aten_ops_exception_list
100+
)
101+
if edge_compile_config._preserve_ops:
102+
_preserve_ops.extend(edge_compile_config._preserve_ops)
94103

95104
class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
96105
dialect = "OLD_EXIR_ATEN"
97106

98107
def __init__(self) -> None:
99108
super().__init__()
100109
# Note: here we are using the exception list passed from EXIRATenDialectVerifier function!
101-
self._exception_list = exception_list if exception_list else []
110+
self._core_aten_ops_exception_list = _core_aten_ops_exception_list
111+
self._preserve_ops = _preserve_ops
102112

103-
def _get_exception_list(self) -> List[torch._ops.OpOverload]:
113+
def _get_core_aten_ops_exception_list(self) -> List[torch._ops.OpOverload]:
104114
exception_list = (
105115
[
106116
torch.ops.aten.mkldnn_rnn_layer.default,
@@ -113,15 +123,35 @@ def _get_exception_list(self) -> List[torch._ops.OpOverload]:
113123
]
114124
+ list(_EXECUTORCH_SYM_OPS)
115125
+ DISALLOW_LIST
116-
+ self._exception_list
126+
+ self._core_aten_ops_exception_list
117127
)
118128

119129
return exception_list
120130

121131
def check_valid_op(self, op):
122132
if isinstance(op, OpOverload):
123133
# TODO These special ops should be removable easily.
124-
if op.namespace != "aten" or op in self._get_exception_list():
134+
if (
135+
op.namespace != "aten"
136+
or op in self._get_core_aten_ops_exception_list()
137+
):
138+
return
139+
if op in self._preserve_ops:
140+
if op.namespace != "aten":
141+
raise RuntimeError(
142+
f"Only preserve aten ops. Received op {op} with namespace {op.namespace}."
143+
)
144+
# Preserved ops should not include mutation or view,
145+
# which may affect memory planning.
146+
if op.is_view:
147+
raise RuntimeError(
148+
f"Cannot preserve operator {op} because it is a view or mutation."
149+
)
150+
if op._schema.is_mutable:
151+
logging.warning(
152+
f"Preserving mutation ops like {op} is a no-op because run_decomposition functionalizes it and prevents it from showing up."
153+
)
154+
125155
return
126156
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
127157
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
@@ -149,7 +179,9 @@ def check_valid_op(self, op):
149179
def get_aten_verifier(config: EdgeCompileConfig):
150180
return (
151181
EXIRATenDialectVerifier(
152-
class_only=True, exception_list=config._core_aten_ops_exception_list
182+
class_only=True,
183+
core_aten_ops_exception_list=config._core_aten_ops_exception_list,
184+
preserve_ops=config._preserve_ops,
153185
)
154186
if config._check_ir_validity
155187
else EXIRATenDialectVerifierBase
@@ -210,13 +242,19 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
210242
def EXIREdgeDialectVerifier( # noqa: C901
211243
edge_compile_config: Optional[EdgeCompileConfig] = None,
212244
class_only: bool = False,
213-
exception_list: Optional[List[torch._ops.OpOverload]] = None,
245+
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
246+
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
214247
):
248+
_core_aten_ops_exception_list = core_aten_ops_exception_list or []
249+
_preserve_ops = preserve_ops or []
215250
# merge the exception list from edge_compile_config and exception_list
216-
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
217-
exception_list = edge_compile_config._core_aten_ops_exception_list + (
218-
exception_list or []
219-
)
251+
if edge_compile_config:
252+
if edge_compile_config._core_aten_ops_exception_list:
253+
_core_aten_ops_exception_list.extend(
254+
edge_compile_config._core_aten_ops_exception_list
255+
)
256+
if edge_compile_config._preserve_ops:
257+
_preserve_ops.extend(edge_compile_config._preserve_ops)
220258

221259
class _EXIREdgeDialectVerifier(Verifier):
222260
dialect = "EDGE"
@@ -228,16 +266,19 @@ def __init__(self) -> None:
228266
self.check_edge_ops = _edge_compile_config._use_edge_ops
229267
self.use_dim_order = not _edge_compile_config._skip_dim_order
230268

269+
self._core_aten_ops_exception_list = _core_aten_ops_exception_list
270+
self._preserve_ops = _preserve_ops
271+
231272
self.aten_op_verifier = EXIRATenDialectVerifier(
232-
exception_list=exception_list
273+
core_aten_ops_exception_list=_core_aten_ops_exception_list,
274+
preserve_ops=_preserve_ops,
233275
)
234276
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op
235277

236278
if self.check_edge_ops:
237279
self.check_valid_op = self.check_valid_edge_op
238280
else:
239281
self.check_valid_op = self.check_valid_aten_op
240-
self._exception_list = exception_list if exception_list else []
241282

242283
def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
243284
return (
@@ -258,7 +299,7 @@ def check_valid_edge_op(self, op):
258299
in [operator.getitem]
259300
+ DISALLOW_LIST
260301
+ list(_EXECUTORCH_SYM_OPS)
261-
+ self._exception_list
302+
+ self._core_aten_ops_exception_list
262303
):
263304
return
264305

0 commit comments

Comments
 (0)