diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py
index 2876568b2fe..dc9797aa596 100644
--- a/backends/apple/coreml/partition/coreml_partitioner.py
+++ b/backends/apple/coreml/partition/coreml_partitioner.py
@@ -28,12 +28,21 @@ class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
     def __init__(
-        self, skip_ops_for_coreml_delegation: Optional[List[str]] = None
+        self,
+        skip_ops_for_coreml_delegation: Optional[List[str]] = None,
+        lower_full_graph: bool = False,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
         super().__init__()
         self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
+        self.lower_full_graph = lower_full_graph
+        self._logged_msgs = set()
+
+    def log_once(self, msg: str) -> None:
+        if msg not in self._logged_msgs:
+            logging.info(msg)
+            self._logged_msgs.add(msg)
 
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         # get_attr node can always be supported on any backend
         if node.op == "get_attr":
             return True
@@ -44,14 +53,63 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             # skip ops if specified by user
             node_target_name = getattr(node.target, "__name__", "").lower()
             if node_target_name in (self.skip_ops_for_coreml_delegation or []):
+                self.log_once(
+                    "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
+                    + node_target_name
+                )
+                assert (
+                    not self.lower_full_graph
+                ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
                 return False
+
+            # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
+            # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
+            # # in the placeholders due to partitioning, which CoreML does not support
+            # if not self.lower_full_graph and any(
+            #     isinstance(arg, torch.fx.Node)
+            #     and isinstance(
+            #         arg.meta.get("val", None),
+            #         (torch.SymInt, torch.SymBool, torch.SymFloat),
+            #     )
+            #     for arg in node.args
+            # ):
+            #     self.log_once(
+            #         "Skipping op for CoreML delegation because it contains symbolic args: "
+            #         + node_target_name
+            #     )
+            #     assert not self.lower_full_graph
+            #     return False
+
             # query coremltools to see if node is supported
-            return ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+            is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
+                node
+            )
+            if not is_supported:
+                if self.lower_full_graph:
+                    raise NotImplementedError(
+                        f"""CoreML does not support the op {node_target_name}, but you have set lower_full_graph=True in the CoreMLPartitioner.
+
+Please set lower_full_graph=False in the CoreMLPartitioner to allow running unsupported ops outside of CoreML. Note that setting lower_full_graph=False may affect performance of CoreML and the available features.
+As an alternative to setting lower_full_graph=False, you can try rewriting your model to avoid using this op.
+
+Also consider filing an issue with Apple's coremltools repo to request support for the op: https://github.com/apple/coremltools/issues
+Do not file an issue with ExecuTorch for op support.
+"""
+                    )
+                self.log_once(
+                    "Skipping op for CoreML delegation because it is not supported by CoreML: "
+                    + node_target_name
+                )
+            return is_supported
         # cowardly refuse to support all other types of node:
         # 1. placeholder / output nodes should not be tagged
         #    reference: https://github.com/pytorch/executorch/pull/1398
         # 2. call_module / call_method should have been replaced with call_function?
         else:
+            self.log_once(
+                "Skipping op for CoreML delegation because it is not get_attr or call_function: "
+                + node.op
+            )
             return False
 
 
@@ -62,6 +120,8 @@ def __init__(
         skip_ops_for_coreml_delegation: Optional[List[str]] = None,
         compile_specs: Optional[List[CompileSpec]] = None,
         take_over_mutable_buffer: Optional[bool] = True,
+        lower_full_graph: bool = False,
+        take_over_constant_data: bool = True,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
@@ -71,6 +131,20 @@ def __init__(
             compile_specs=compile_specs if compile_specs is not None else [],
         )
         self.take_over_mutable_buffer = take_over_mutable_buffer
+        self.lower_full_graph = lower_full_graph
+        self.take_over_constant_data = take_over_constant_data
+        self._logged_msgs = set()
+
+        if self.lower_full_graph:
+            assert (
+                len(self.skip_ops_for_coreml_delegation) == 0
+            ), "When lower_full_graph=True, you cannot set skip_ops_for_coreml_delegation"
+            assert (
+                self.take_over_constant_data
+            ), "When lower_full_graph=True, you must set take_over_constant_data=True"
+            assert (
+                self.take_over_mutable_buffer
+            ), "When lower_full_graph=True, you must set take_over_mutable_buffer=True"
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         # Run the CapabilityBasedPartitioner to return the largest possible
@@ -80,7 +154,9 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
 
         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            OperatorsSupportedForCoreMLBackend(self.skip_ops_for_coreml_delegation),
+            OperatorsSupportedForCoreMLBackend(
+                self.skip_ops_for_coreml_delegation, self.lower_full_graph
+            ),
             allows_single_node_partition=True,
         )
         partition_list = capability_partitioner.propose_partitions()
@@ -90,7 +166,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 node.meta["delegation_tag"] = tag
                 partition_tags[tag] = self.delegation_spec
 
-        tag_constant_data(exported_program)
+        if self.take_over_constant_data:
+            tag_constant_data(exported_program)
         if self.take_over_mutable_buffer:
             logger.info(
                 "Core ML partitioner will take over torch mutable buffer as Core ML state, "
@@ -105,12 +182,18 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
 
+    def log_once(self, msg: str) -> None:
+        if msg not in self._logged_msgs:
+            logging.info(msg)
+            self._logged_msgs.add(msg)
+
     def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
         do_not_decompose = []
-        op_support = OperatorsSupportedForCoreMLBackend()
-        _logged_warnings = set()
+        op_support = OperatorsSupportedForCoreMLBackend(
+            self.skip_ops_for_coreml_delegation, self.lower_full_graph
+        )
 
         # CoreML prevents certain ops (like triu) from lowering to CoreML when put in the ExecuTorch op namespace
         # TODO: upstream fixes, but pending ET consuming a new published version of coremltools with the
@@ -134,9 +217,7 @@ def ops_to_not_decompose(
                 except Exception as e:
                     # CoreML's op_support.is_node_supported will sometimes throw
                     # for unsupported ops, rather than returning False
-                    warn_str = f"Encountered exception when checking node support: {e}"
-                    if warn_str not in _logged_warnings:
-                        logger.warning(warn_str)
-                        _logged_warnings.add(warn_str)
-
+                    self.log_once(
+                        f"Encountered exception when checking node support, treating node as unsupported: {e}"
+                    )
         return do_not_decompose, None
diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py
index 0a63c43414f..368a24a4a28 100644
--- a/backends/apple/coreml/test/test_coreml_partitioner.py
+++ b/backends/apple/coreml/test/test_coreml_partitioner.py
@@ -2,6 +2,8 @@
 #
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
+import copy
+import sys
 import unittest
 
 import coremltools as ct
@@ -14,6 +16,28 @@
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
 from executorch.exir.backend.utils import format_delegated_graph
+from executorch.runtime import Runtime
+
+
+@torch.library.custom_op("unsupported::linear", mutates_args=())
+def _(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.aten.linear.default(x, w, b)
+
+
+@torch.library.register_fake("unsupported::linear")
+def _(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.aten.linear.default(x, w, b)
+
+
+_TEST_RUNTIME = sys.platform == "darwin"
 
 
 class TestCoreMLPartitioner(unittest.TestCase):
@@ -200,6 +224,113 @@ def forward(self, q, k_val, input_pos):
             "getitem",
         ]
 
+    def test_lower_full_graph(self):
+        class Model(torch.nn.Module):
+            def forward(self, a, x, b):
+                out = torch.ops.aten.linear.default(a, x, b)
+                out2 = torch.ops.unsupported.linear.default(out, x, b)
+                return out2
+
+        model = Model()
+        model.eval()
+
+        example_inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
+        exir_program_aten = torch.export.export(model, example_inputs, strict=True)
+        edge_program_manager = executorch.exir.to_edge(exir_program_aten)
+        edge_program_manager2 = copy.deepcopy(edge_program_manager)
+
+        delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
+
+        for node in delegated_program_manager.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "unsupported.linear.default",
+                    "executorch_call_delegate",
+                    "getitem",
+                ], node.target.__name__
+
+        with self.assertRaises(NotImplementedError):
+            edge_program_manager2.to_backend(CoreMLPartitioner(lower_full_graph=True))
+
+    # TODO: enable this after bugs are fixed in ExecuTorch's partitioner
+    # def test_symint_arg(self):
+    #     class Model(torch.nn.Module):
+    #         def forward(self, x, w, b, y):
+    #             val = y.item()
+    #             torch._check(val >= 0)
+    #             torch._check(val < 2)
+    #             out = torch.ops.aten.linear.default(x, w, b)
+    #             out2 = out.relu()[val]
+    #             return out2
+
+    #     model = Model()
+    #     model.eval()
+    #     example_inputs = (
+    #         torch.randn(2, 2),
+    #         torch.randn(2, 2),
+    #         torch.randn(2, 2),
+    #         torch.tensor(2),
+    #     )
+    #     exir_program_aten = torch.export.export(model, example_inputs)

+    #     edge_program_manager = executorch.exir.to_edge(exir_program_aten)

+    #     delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner(skip_ops_for_coreml_delegation=["aten.scalar_tensor.default"]))

+    #     # This op has symbolic args
+    #     assert (
+    #         "torch.ops.aten._assert_scalar.default"
+    #         in delegated_program_manager.exported_program().graph_module.code
+    #     )

+    #     if _TEST_RUNTIME:
+    #         et_prog = delegated_program_manager.to_executorch()
+    #         runtime = Runtime.get()
+    #         program = runtime.load_program(et_prog.buffer)
+    #         method = program.load_method("forward")
+    #         et_outputs = method.execute(*example_inputs)[0]
+    #         eager_outputs = model(*example_inputs)
+    #         self.assertTrue(torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02))
+
+    def test_take_over_constant_data_false(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(50, 100)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = Model()
+        model.eval()
+        example_inputs = (torch.randn(2, 50),)
+        exir_program_aten = torch.export.export(model, example_inputs)
+
+        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
+            exir_program_aten,
+            partitioner=[CoreMLPartitioner(take_over_constant_data=False)],
+        )
+        for node in edge_program_manager.exported_program().graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target.__name__ == "executorch_call_delegate"
+            ):
+                break
+
+        # lowered_module_0, x, p_linear_weight, p_linear_bias
+        assert len(node.args) == 4
+
+        if _TEST_RUNTIME:
+            et_prog = edge_program_manager.to_executorch()
+            runtime = Runtime.get()
+            program = runtime.load_program(et_prog.buffer)
+            method = program.load_method("forward")
+            et_outputs = method.execute(*example_inputs)[0]
+            eager_outputs = model(*example_inputs)
+            self.assertTrue(
+                torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
+            )
+
 
 if __name__ == "__main__":
     test_runner = TestCoreMLPartitioner()
@@ -207,3 +338,6 @@ def forward(self, q, k_val, input_pos):
     test_runner.test_vit_skip_conv()
     test_runner.test_ops_to_not_decompose()
     test_runner.test_buffer()
+    test_runner.test_lower_full_graph()
+    # test_runner.test_symint_arg()
+    test_runner.test_take_over_constant_data_false()
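For reference, a minimal usage sketch (not part of the patch) of the new CoreMLPartitioner options, mirroring how the tests above exercise them. The TinyLinear module and its shapes are illustrative assumptions; the partitioner flags and lowering APIs are the ones used in test_coreml_partitioner.py.

import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    # Hypothetical stand-in model; any exportable module works here.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(50, 100)

    def forward(self, x):
        return self.linear(x)


model = TinyLinear().eval()
example_inputs = (torch.randn(2, 50),)
exported = torch.export.export(model, example_inputs)

# lower_full_graph=True asks the partitioner to delegate the entire graph to Core ML
# and raise NotImplementedError if any op is unsupported, rather than running that op
# outside of Core ML. It requires the default take_over_constant_data=True and
# take_over_mutable_buffer=True. Passing take_over_constant_data=False instead keeps
# weights as ExecuTorch-owned inputs to the delegate, as in
# test_take_over_constant_data_false above.
lowered = to_edge_transform_and_lower(
    exported,
    partitioner=[CoreMLPartitioner(lower_full_graph=True)],
)
et_program = lowered.to_executorch()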