
CoreML partitioner improvements #12532


Merged: 5 commits, Jul 18, 2025
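This PR adds two options to CoreMLPartitioner: lower_full_graph, which makes partitioning fail fast (with a NotImplementedError) if any op in the graph cannot be delegated to Core ML, and take_over_constant_data, which controls whether constant data is tagged for consumption by the delegate. Below is a minimal usage sketch based on the export flow exercised in the tests in this PR; the TinyModel class and its shapes are illustrative only, not part of the change:

import torch
import executorch.exir
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(50, 100)

    def forward(self, x):
        return self.linear(x)

model = TinyModel().eval()
example_inputs = (torch.randn(2, 50),)
exported = torch.export.export(model, example_inputs)

# lower_full_graph=True raises NotImplementedError if Core ML cannot handle an op,
# instead of silently running that op outside the delegate.
edge = executorch.exir.to_edge_transform_and_lower(
    exported,
    partitioner=[CoreMLPartitioner(lower_full_graph=True)],
)
et_program = edge.to_executorch()
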
103 changes: 92 additions & 11 deletions backends/apple/coreml/partition/coreml_partitioner.py
@@ -28,12 +28,21 @@

class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
def __init__(
-        self, skip_ops_for_coreml_delegation: Optional[List[str]] = None
+        self,
+        skip_ops_for_coreml_delegation: Optional[List[str]] = None,
+        lower_full_graph: bool = False,
) -> None:
if skip_ops_for_coreml_delegation is None:
skip_ops_for_coreml_delegation = []
super().__init__()
self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
self.lower_full_graph = lower_full_graph
self._logged_msgs = set()

def log_once(self, msg: str) -> None:
if msg not in self._logged_msgs:
logging.info(msg)
self._logged_msgs.add(msg)

def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
# get_attr node can always be supported on any backend
@@ -44,14 +53,63 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
# skip ops if specified by user
node_target_name = getattr(node.target, "__name__", "").lower()
if node_target_name in (self.skip_ops_for_coreml_delegation or []):
self.log_once(
"Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
+ node_target_name
)
assert (
not self.lower_full_graph
), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
return False

# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
# # in the placeholders due to partitioning, which CoreML does not support
# if not self.lower_full_graph and any(
# isinstance(arg, torch.fx.Node)
# and isinstance(
# arg.meta.get("val", None),
# (torch.SymInt, torch.SymBool, torch.SymFloat),
# )
# for arg in node.args
# ):
# self.log_once(
# "Skipping op for CoreML delegation because it contains symbolic args: "
# + node_target_name
# )
# assert not self.lower_full_graph
# return False

# query coremltools to see if node is supported
-            return ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+            is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
+                node
+            )
if not is_supported:
if self.lower_full_graph:
raise NotImplementedError(
f"""CoreML does not support the op {node_target_name}, but you have set lower_full_graph=True in the CoreMLPartitioner.

Please set lower_full_graph=False in the CoreMLPartitioner to allow running unsupported ops outside of CoreML. Note that setting lower_full_graph=False may affect performance of CoreML and the available features.
As an alternative to setting lower_full_graph=False, you can try rewriting your model to avoid using this op.

Also consider filing an issue with Apple's coremltools repo to request support for the op: https://github.com/apple/coremltools/issues
Do not file an issue with ExecuTorch for op support.
"""
)
self.log_once(
"Skipping op for CoreML delegation because it is not supported by CoreML: "
+ node_target_name
)
return is_supported
# cowardly refuse to support all other types of node:
# 1. placeholder / output nodes should not be tagged
# reference: https://github.com/pytorch/executorch/pull/1398
# 2. call_module / call_method should have been replaced with call_function?
else:
self.log_once(
"Skipping op for CoreML delegation because it is not get_attr or call_function: "
+ node.op
)
return False


@@ -62,6 +120,8 @@ def __init__(
skip_ops_for_coreml_delegation: Optional[List[str]] = None,
compile_specs: Optional[List[CompileSpec]] = None,
take_over_mutable_buffer: Optional[bool] = True,
lower_full_graph: bool = False,
take_over_constant_data: bool = True,
) -> None:
if skip_ops_for_coreml_delegation is None:
skip_ops_for_coreml_delegation = []
@@ -71,6 +131,20 @@ def __init__(
compile_specs=compile_specs if compile_specs is not None else [],
)
self.take_over_mutable_buffer = take_over_mutable_buffer
self.lower_full_graph = lower_full_graph
self.take_over_constant_data = take_over_constant_data
self._logged_msgs = set()

if self.lower_full_graph:
assert (
len(self.skip_ops_for_coreml_delegation) == 0
), "When lower_full_graph=True, you cannot set skip_ops_for_coreml_delegation"
assert (
self.take_over_constant_data
), "When lower_full_graph=True, you must set take_over_constant_data=True"
assert (
self.take_over_mutable_buffer
), "When lower_full_graph=True, you must set take_over_mutable_buffer=True"

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# Run the CapabilityBasedPartitioner to return the largest possible
@@ -80,7 +154,9 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
-            OperatorsSupportedForCoreMLBackend(self.skip_ops_for_coreml_delegation),
+            OperatorsSupportedForCoreMLBackend(
+                self.skip_ops_for_coreml_delegation, self.lower_full_graph
+            ),
allows_single_node_partition=True,
)
partition_list = capability_partitioner.propose_partitions()
@@ -90,7 +166,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec

-        tag_constant_data(exported_program)
+        if self.take_over_constant_data:
+            tag_constant_data(exported_program)
if self.take_over_mutable_buffer:
logger.info(
"Core ML partitioner will take over torch mutable buffer as Core ML state, "
@@ -105,12 +182,18 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
tagged_exported_program=exported_program, partition_tags=partition_tags
)

def log_once(self, msg: str) -> None:
if msg not in self._logged_msgs:
logging.info(msg)
self._logged_msgs.add(msg)

def ops_to_not_decompose(
self, ep: ExportedProgram
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
do_not_decompose = []
-        op_support = OperatorsSupportedForCoreMLBackend()
-        _logged_warnings = set()
+        op_support = OperatorsSupportedForCoreMLBackend(
+            self.skip_ops_for_coreml_delegation, self.lower_full_graph
+        )

# CoreML prevents certain ops (like triu) from lowering to CoreML when put in the ExecuTorch op namespace
# TODO: upstream fixes, but pending ET consuming a new published version of coremltools with the
@@ -134,9 +217,7 @@ def ops_to_not_decompose(
except Exception as e:
# CoreML's op_support.is_node_supported will sometimes throw
# for unsupported ops, rather than returning False
-                    warn_str = f"Encountered exception when checking node support: {e}"
-                    if warn_str not in _logged_warnings:
-                        logger.warning(warn_str)
-                        _logged_warnings.add(warn_str)
+                    self.log_once(
+                        f"Encountered exception when checking node support, treating node as unsupported: {e}"
+                    )
return do_not_decompose, None
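
With lower_full_graph=True, the constructor rejects incompatible settings up front (the asserts above), and is_node_supported raises NotImplementedError on the first op Core ML cannot handle, which propagates out of to_backend / to_edge_transform_and_lower. A small sketch of the constructor check; the skipped op name is just an example string:

from executorch.backends.apple.coreml.partition import CoreMLPartitioner

try:
    # Not allowed: lower_full_graph=True requires an empty skip list.
    CoreMLPartitioner(
        lower_full_graph=True,
        skip_ops_for_coreml_delegation=["aten.mm.default"],
    )
except AssertionError as err:
    print(err)  # "When lower_full_graph=True, you cannot set skip_ops_for_coreml_delegation"
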
134 changes: 134 additions & 0 deletions backends/apple/coreml/test/test_coreml_partitioner.py
@@ -2,6 +2,8 @@
#
# Please refer to the license found in the LICENSE file in the root directory of the source tree.

import copy
import sys
import unittest

import coremltools as ct
@@ -14,6 +16,28 @@
from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir.backend.utils import format_delegated_graph
from executorch.runtime import Runtime


@torch.library.custom_op("unsupported::linear", mutates_args=())
def _(
x: torch.Tensor,
w: torch.Tensor,
b: torch.Tensor,
) -> torch.Tensor:
return torch.ops.aten.linear.default(x, w, b)


@torch.library.register_fake("unsupported::linear")
def _(
x: torch.Tensor,
w: torch.Tensor,
b: torch.Tensor,
) -> torch.Tensor:
return torch.ops.aten.linear.default(x, w, b)


_TEST_RUNTIME = sys.platform == "darwin"


class TestCoreMLPartitioner(unittest.TestCase):
@@ -200,10 +224,120 @@ def forward(self, q, k_val, input_pos):
"getitem",
]

def test_lower_full_graph(self):
class Model(torch.nn.Module):
def forward(self, a, x, b):
out = torch.ops.aten.linear.default(a, x, b)
out2 = torch.ops.unsupported.linear.default(out, x, b)
return out2

model = Model()
model.eval()

example_inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
exir_program_aten = torch.export.export(model, example_inputs, strict=True)
edge_program_manager = executorch.exir.to_edge(exir_program_aten)
edge_program_manager2 = copy.deepcopy(edge_program_manager)

delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())

for node in delegated_program_manager.exported_program().graph.nodes:
if node.op == "call_function":
assert node.target.__name__ in [
"unsupported.linear.default",
"executorch_call_delegate",
"getitem",
], node.target.__name__

with self.assertRaises(NotImplementedError):
edge_program_manager2.to_backend(CoreMLPartitioner(lower_full_graph=True))

# TODO: enable this after bugs are fixed in ExecuTorch's partitioner
# def test_symint_arg(self):
# class Model(torch.nn.Module):
# def forward(self, x, w, b, y):
# val = y.item()
# torch._check(val >= 0)
# torch._check(val < 2)
# out = torch.ops.aten.linear.default(x, w, b)
# out2 = out.relu()[val]
# return out2

# model = Model()
# model.eval()
# example_inputs = (
# torch.randn(2, 2),
# torch.randn(2, 2),
# torch.randn(2, 2),
# torch.tensor(2),
# )
# exir_program_aten = torch.export.export(model, example_inputs)

# edge_program_manager = executorch.exir.to_edge(exir_program_aten)

# delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner(skip_ops_for_coreml_delegation=["aten.scalar_tensor.default"]))

# # This op has symbolic args
# assert (
# "torch.ops.aten._assert_scalar.default"
# in delegated_program_manager.exported_program().graph_module.code
# )

# if _TEST_RUNTIME:
# et_prog = delegated_program_manager.to_executorch()
# runtime = Runtime.get()
# program = runtime.load_program(et_prog.buffer)
# method = program.load_method("forward")
# et_outputs = method.execute(*example_inputs)[0]
# eager_outputs = model(*example_inputs)
# self.assertTrue(torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02))

def test_take_over_constant_data_false(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(50, 100)

def forward(self, x):
return self.linear(x)

model = Model()
model.eval()
example_inputs = (torch.randn(2, 50),)
exir_program_aten = torch.export.export(model, example_inputs)

edge_program_manager = executorch.exir.to_edge_transform_and_lower(
exir_program_aten,
partitioner=[CoreMLPartitioner(take_over_constant_data=False)],
)
for node in edge_program_manager.exported_program().graph.nodes:
if (
node.op == "call_function"
and node.target.__name__ == "executorch_call_delegate"
):
break

# lowered_module_0, x, p_linear_weight, p_linear_bias
assert len(node.args) == 4

if _TEST_RUNTIME:
et_prog = edge_program_manager.to_executorch()
runtime = Runtime.get()
program = runtime.load_program(et_prog.buffer)
method = program.load_method("forward")
et_outputs = method.execute(*example_inputs)[0]
eager_outputs = model(*example_inputs)
self.assertTrue(
torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
)


if __name__ == "__main__":
test_runner = TestCoreMLPartitioner()
test_runner.test_add_sub_skip_mm()
test_runner.test_vit_skip_conv()
test_runner.test_ops_to_not_decompose()
test_runner.test_buffer()
test_runner.test_lower_full_graph()
# test_runner.test_symint_arg()
test_runner.test_take_over_constant_data_false()
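
As an aside, format_delegated_graph (imported at the top of this test file) is useful for checking which nodes actually landed inside the Core ML delegate after partitioning. A short sketch, assuming an edge_program_manager built as in the tests above:

from executorch.exir.backend.utils import format_delegated_graph

# Prints the graph with the lowered submodules included, so you can see
# which call_function nodes were replaced by executorch_call_delegate.
print(
    format_delegated_graph(
        edge_program_manager.exported_program().graph_module
    )
)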