
Commit 31c276d

CoreML partitioner improvements (#12532)
This makes several improvements to the CoreML partitioner:

* We now log when nodes are skipped during partitioning, along with the reason they were skipped.
* We add a new partitioner option, lower_full_graph, that raises an exception if the model cannot be fully delegated. In the future, when this option is enabled, we plan to support additional CoreML features such as enumerated shapes.
* We add a take_over_constant_data option that tells the CoreML delegate not to consume weight data.
1 parent 8d0dbc2 commit 31c276d
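For context, here is a minimal usage sketch of the new options, not taken from this commit; the toy model, shapes, and variable names are illustrative only:

# Hedged sketch (illustrative, not from this commit): wiring the new
# CoreMLPartitioner options into an ExecuTorch export flow.
import torch
import executorch.exir
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

model = torch.nn.Linear(50, 100).eval()  # toy model for illustration
example_inputs = (torch.randn(2, 50),)
exported = torch.export.export(model, example_inputs)
edge = executorch.exir.to_edge(exported)

# lower_full_graph=True makes partitioning all-or-nothing: a
# NotImplementedError is raised if any op cannot be delegated to CoreML.
# Per the asserts added in this commit, it cannot be combined with
# skip_ops_for_coreml_delegation, take_over_constant_data=False, or
# take_over_mutable_buffer=False.
delegated = edge.to_backend(CoreMLPartitioner(lower_full_graph=True))

# Alternatively, take_over_constant_data=False leaves weights in the
# ExecuTorch program; they are passed to the CoreML delegate as call
# arguments rather than consumed into the delegate payload.
# delegated = edge.to_backend(CoreMLPartitioner(take_over_constant_data=False))

The assertions added in CoreMLPartitioner.__init__ (see the diff below) enforce these option combinations at construction time.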

File tree

backends/apple/coreml/partition/coreml_partitioner.py
backends/apple/coreml/test/test_coreml_partitioner.py

2 files changed: +226 / -11 lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 92 additions & 11 deletions
@@ -28,12 +28,21 @@

 class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
     def __init__(
-        self, skip_ops_for_coreml_delegation: Optional[List[str]] = None
+        self,
+        skip_ops_for_coreml_delegation: Optional[List[str]] = None,
+        lower_full_graph: bool = False,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
         super().__init__()
         self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
+        self.lower_full_graph = lower_full_graph
+        self._logged_msgs = set()
+
+    def log_once(self, msg: str) -> None:
+        if msg not in self._logged_msgs:
+            logging.info(msg)
+            self._logged_msgs.add(msg)

     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         # get_attr node can always be supported on any backend
@@ -44,14 +53,63 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             # skip ops if specified by user
             node_target_name = getattr(node.target, "__name__", "").lower()
             if node_target_name in (self.skip_ops_for_coreml_delegation or []):
+                self.log_once(
+                    "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
+                    + node_target_name
+                )
+                assert (
+                    not self.lower_full_graph
+                ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
                 return False
+
+            # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
+            # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
+            # # in the placeholders due to partitioning, which CoreML does not support
+            # if not self.lower_full_graph and any(
+            #     isinstance(arg, torch.fx.Node)
+            #     and isinstance(
+            #         arg.meta.get("val", None),
+            #         (torch.SymInt, torch.SymBool, torch.SymFloat),
+            #     )
+            #     for arg in node.args
+            # ):
+            #     self.log_once(
+            #         "Skipping op for CoreML delegation because it contains symbolic args: "
+            #         + node_target_name
+            #     )
+            #     assert not self.lower_full_graph
+            #     return False
+
             # query coremltools to see if node is supported
-            return ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+            is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
+                node
+            )
+            if not is_supported:
+                if self.lower_full_graph:
+                    raise NotImplementedError(
+                        f"""CoreML does not support the op {node_target_name}, but you have set lower_full_graph=True in the CoreMLPartitioner.
+
+Please set lower_full_graph=False in the CoreMLPartitioner to allow running unsupported ops outside of CoreML. Note that setting lower_full_graph=False may affect performance of CoreML and the available features.
+As an alternative to setting lower_full_graph=False, you can try rewriting your model to avoid using this op.
+
+Also consider filing an issue with Apple's coremltools repo to request support for the op: https://github.com/apple/coremltools/issues
+Do not file an issue with ExecuTorch for op support.
+"""
+                    )
+                self.log_once(
+                    "Skipping op for CoreML delegation because it is not supported by CoreML: "
+                    + node_target_name
+                )
+            return is_supported
         # cowardly refuse to support all other types of node:
         # 1. placeholder / output nodes should not be tagged
         #    reference: https://github.com/pytorch/executorch/pull/1398
         # 2. call_module / call_method should have been replaced with call_function?
         else:
+            self.log_once(
+                "Skipping op for CoreML delegation because it is not get_attr or call_function: "
+                + node.op
+            )
             return False

@@ -62,6 +120,8 @@ def __init__(
         skip_ops_for_coreml_delegation: Optional[List[str]] = None,
         compile_specs: Optional[List[CompileSpec]] = None,
         take_over_mutable_buffer: Optional[bool] = True,
+        lower_full_graph: bool = False,
+        take_over_constant_data: bool = True,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
@@ -71,6 +131,20 @@ def __init__(
             compile_specs=compile_specs if compile_specs is not None else [],
         )
         self.take_over_mutable_buffer = take_over_mutable_buffer
+        self.lower_full_graph = lower_full_graph
+        self.take_over_constant_data = take_over_constant_data
+        self._logged_msgs = set()
+
+        if self.lower_full_graph:
+            assert (
+                len(self.skip_ops_for_coreml_delegation) == 0
+            ), "When lower_full_graph=True, you cannot set skip_ops_for_coreml_delegation"
+            assert (
+                self.take_over_constant_data
+            ), "When lower_full_graph=True, you must set take_over_constant_data=True"
+            assert (
+                self.take_over_mutable_buffer
+            ), "When lower_full_graph=True, you must set take_over_mutable_buffer=True"

     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         # Run the CapabilityBasedPartitioner to return the largest possible
@@ -80,7 +154,9 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            OperatorsSupportedForCoreMLBackend(self.skip_ops_for_coreml_delegation),
+            OperatorsSupportedForCoreMLBackend(
+                self.skip_ops_for_coreml_delegation, self.lower_full_graph
+            ),
             allows_single_node_partition=True,
         )
         partition_list = capability_partitioner.propose_partitions()
@@ -90,7 +166,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 node.meta["delegation_tag"] = tag
                 partition_tags[tag] = self.delegation_spec

-        tag_constant_data(exported_program)
+        if self.take_over_constant_data:
+            tag_constant_data(exported_program)
         if self.take_over_mutable_buffer:
             logger.info(
                 "Core ML partitioner will take over torch mutable buffer as Core ML state, "
@@ -105,12 +182,18 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )

+    def log_once(self, msg: str) -> None:
+        if msg not in self._logged_msgs:
+            logging.info(msg)
+            self._logged_msgs.add(msg)
+
     def ops_to_not_decompose(
         self, ep: ExportedProgram
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
         do_not_decompose = []
-        op_support = OperatorsSupportedForCoreMLBackend()
-        _logged_warnings = set()
+        op_support = OperatorsSupportedForCoreMLBackend(
+            self.skip_ops_for_coreml_delegation, self.lower_full_graph
+        )

         # CoreML prevents certain ops (like triu) from lowering to CoreML when put in the ExecuTorch op namespace
         # TODO: upstream fixes, but pending ET consuming a new published version of coremltools with the
@@ -134,9 +217,7 @@ def ops_to_not_decompose(
             except Exception as e:
                 # CoreML's op_support.is_node_supported will sometimes throw
                 # for unsupported ops, rather than returning False
-                warn_str = f"Encountered exception when checking node support: {e}"
-                if warn_str not in _logged_warnings:
-                    logger.warning(warn_str)
-                    _logged_warnings.add(warn_str)
-
+                self.log_once(
+                    f"Encountered exception when checking node support, treating node as unsupported: {e}"
+                )
         return do_not_decompose, None

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 134 additions & 0 deletions
@@ -2,6 +2,8 @@
 #
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.

+import copy
+import sys
 import unittest

 import coremltools as ct
@@ -14,6 +16,28 @@
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
 from executorch.exir.backend.utils import format_delegated_graph
+from executorch.runtime import Runtime
+
+
+@torch.library.custom_op("unsupported::linear", mutates_args=())
+def _(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.aten.linear.default(x, w, b)
+
+
+@torch.library.register_fake("unsupported::linear")
+def _(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.aten.linear.default(x, w, b)
+
+
+_TEST_RUNTIME = sys.platform == "darwin"


 class TestCoreMLPartitioner(unittest.TestCase):
@@ -200,10 +224,120 @@ def forward(self, q, k_val, input_pos):
             "getitem",
         ]

+    def test_lower_full_graph(self):
+        class Model(torch.nn.Module):
+            def forward(self, a, x, b):
+                out = torch.ops.aten.linear.default(a, x, b)
+                out2 = torch.ops.unsupported.linear.default(out, x, b)
+                return out2
+
+        model = Model()
+        model.eval()
+
+        example_inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
+        exir_program_aten = torch.export.export(model, example_inputs, strict=True)
+        edge_program_manager = executorch.exir.to_edge(exir_program_aten)
+        edge_program_manager2 = copy.deepcopy(edge_program_manager)
+
+        delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
+
+        for node in delegated_program_manager.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "unsupported.linear.default",
+                    "executorch_call_delegate",
+                    "getitem",
+                ], node.target.__name__
+
+        with self.assertRaises(NotImplementedError):
+            edge_program_manager2.to_backend(CoreMLPartitioner(lower_full_graph=True))
+
+    # TODO: enable this after bugs are fixed in ExecuTorch's partitioner
+    # def test_symint_arg(self):
+    #     class Model(torch.nn.Module):
+    #         def forward(self, x, w, b, y):
+    #             val = y.item()
+    #             torch._check(val >= 0)
+    #             torch._check(val < 2)
+    #             out = torch.ops.aten.linear.default(x, w, b)
+    #             out2 = out.relu()[val]
+    #             return out2
+
+    #     model = Model()
+    #     model.eval()
+    #     example_inputs = (
+    #         torch.randn(2, 2),
+    #         torch.randn(2, 2),
+    #         torch.randn(2, 2),
+    #         torch.tensor(2),
+    #     )
+    #     exir_program_aten = torch.export.export(model, example_inputs)

+    #     edge_program_manager = executorch.exir.to_edge(exir_program_aten)

+    #     delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner(skip_ops_for_coreml_delegation=["aten.scalar_tensor.default"]))

+    #     # This op has symbolic args
+    #     assert (
+    #         "torch.ops.aten._assert_scalar.default"
+    #         in delegated_program_manager.exported_program().graph_module.code
+    #     )

+    #     if _TEST_RUNTIME:
+    #         et_prog = delegated_program_manager.to_executorch()
+    #         runtime = Runtime.get()
+    #         program = runtime.load_program(et_prog.buffer)
+    #         method = program.load_method("forward")
+    #         et_outputs = method.execute(*example_inputs)[0]
+    #         eager_outputs = model(*example_inputs)
+    #         self.assertTrue(torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02))

+    def test_take_over_constant_data_false(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(50, 100)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = Model()
+        model.eval()
+        example_inputs = (torch.randn(2, 50),)
+        exir_program_aten = torch.export.export(model, example_inputs)
+
+        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
+            exir_program_aten,
+            partitioner=[CoreMLPartitioner(take_over_constant_data=False)],
+        )
+        for node in edge_program_manager.exported_program().graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target.__name__ == "executorch_call_delegate"
+            ):
+                break
+
+        # lowered_module_0, x, p_linear_weight, p_linear_bias
+        assert len(node.args) == 4
+
+        if _TEST_RUNTIME:
+            et_prog = edge_program_manager.to_executorch()
+            runtime = Runtime.get()
+            program = runtime.load_program(et_prog.buffer)
+            method = program.load_method("forward")
+            et_outputs = method.execute(*example_inputs)[0]
+            eager_outputs = model(*example_inputs)
+            self.assertTrue(
+                torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
+            )
+


 if __name__ == "__main__":
     test_runner = TestCoreMLPartitioner()
     test_runner.test_add_sub_skip_mm()
     test_runner.test_vit_skip_conv()
     test_runner.test_ops_to_not_decompose()
     test_runner.test_buffer()
+    test_runner.test_lower_full_graph()
+    # test_runner.test_symint_arg()
+    test_runner.test_take_over_constant_data_false()
