-
Notifications
You must be signed in to change notification settings - Fork 741
Arm Backend: improve non-persistent placeholder and bool handling #15992
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| # Copyright 2025 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| # The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR operators don't handle bool inputs. | ||
| # When a targeted op (the bitwise ops above, plus aten.mul.Tensor) receives boolean tensors, | ||
| # we promote the bool operands to a non-bool dtype before invocation. If every operand was | ||
| # bool, the result is cast back to bool afterwards. | ||
|
|
||
| from typing import Set, Type | ||
|
|
||
| import torch | ||
|
|
||
| from executorch.backends.arm._passes.arm_pass import ArmPass | ||
| from executorch.exir.dialects._ops import ops as exir_ops | ||
| from executorch.exir.pass_base import ExportPass | ||
|
|
||
|
|
||
class PromoteBoolOperandsPass(ArmPass):
    """Cast bool operands of the targeted ops to a non-bool dtype.

    An op in ``targeted_ops`` that receives at least one ``torch.bool``
    operand has each bool operand cast (via ``_to_dim_order_copy``) to the
    dtype of its first non-bool operand. If every operand is bool, the
    operands are cast to ``torch.int32`` for ``mul`` and to ``torch.int8``
    for the bitwise ops, and the op's result is cast back to ``torch.bool``.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    targeted_ops = {
        exir_ops.edge.aten.bitwise_and.Tensor,
        exir_ops.edge.aten.bitwise_or.Tensor,
        exir_ops.edge.aten.bitwise_xor.Tensor,
        exir_ops.edge.aten.mul.Tensor,
    }

    def _cast_to_dtype(self, value, dtype, meta):
        """Emit a ``_to_dim_order_copy`` node converting ``value`` to ``dtype``."""
        return super().call_operator(
            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
            (value,),
            {"dtype": dtype},
            meta,
        )

    def call_operator(self, op, args, kwargs, meta):
        """Rewrite a targeted op so that none of its operands is bool.

        Ops outside ``targeted_ops``, and targeted ops without any bool
        operand, are forwarded to the parent implementation unchanged.
        """
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta)

        operand_dtypes = [operand.data.dtype for operand in args]
        if torch.bool not in operand_dtypes:
            return super().call_operator(op, args, kwargs, meta)

        non_bool_dtypes = [dt for dt in operand_dtypes if dt != torch.bool]
        all_bool = not non_bool_dtypes

        if not all_bool:
            # Mixed operands: bool operands adopt the first non-bool dtype.
            cast_dtype = non_bool_dtypes[0]
        elif op == exir_ops.edge.aten.mul.Tensor:
            # No operand supplies a dtype; the default chosen for mul is int32.
            cast_dtype = torch.int32
        else:
            # No operand supplies a dtype; the bitwise ops can use int8.
            cast_dtype = torch.int8

        # Only bool operands are cast; every other operand passes through.
        promoted_args = tuple(
            self._cast_to_dtype(operand, cast_dtype, meta)
            if dtype == torch.bool
            else operand
            for operand, dtype in zip(args, operand_dtypes)
        )

        result = super().call_operator(op, promoted_args, kwargs, meta)

        if all_bool:
            # All inputs were bool, so restore the bool result dtype.
            result = self._cast_to_dtype(result, torch.bool, meta)
        return result
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |||||||||||||||||||||||||||
| from executorch.exir import ExportedProgram | ||||||||||||||||||||||||||||
| from executorch.exir.pass_base import ExportPass, PassResult | ||||||||||||||||||||||||||||
| from torch._export.utils import is_buffer, is_param | ||||||||||||||||||||||||||||
| from torch.export.graph_signature import InputKind | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class UnsqueezeScalarPlaceholdersPass(ArmPass): | ||||||||||||||||||||||||||||
|
|
@@ -42,17 +43,30 @@ def call(self, graph_module: torch.fx.GraphModule): | |||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| tensor = self.exported_program.state_dict[name] | ||||||||||||||||||||||||||||
| tensor = self.exported_program.state_dict.get(name) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # If we have a persistent=False buffer with no entry in state_dict | ||||||||||||||||||||||||||||
| spec = next( | ||||||||||||||||||||||||||||
| s | ||||||||||||||||||||||||||||
| for s in self.exported_program.graph_signature.input_specs | ||||||||||||||||||||||||||||
| if getattr(s.arg, "name", None) == node.name | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
Comment on lines
+50
to
+53
|
||||||||||||||||||||||||||||
| s | |
| for s in self.exported_program.graph_signature.input_specs | |
| if getattr(s.arg, "name", None) == node.name | |
| ) | |
| (s | |
| for s in self.exported_program.graph_signature.input_specs | |
| if getattr(s.arg, "name", None) == node.name), | |
| None | |
| ) | |
| if spec is None: | |
| raise ValueError( | |
| f"No matching input spec found for placeholder node '{node.name}'." | |
| ) |
Copilot
AI
Nov 26, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding a comment explaining why torch.ones_like is used here for non-persistent buffers. The choice of ones (versus zeros or random values) may not be immediately clear to future maintainers.
| fake = node.meta["val"] | |
| fake = node.meta["val"] | |
| # For non-persistent buffers, we initialize with ones to provide a consistent, non-zero default. | |
| # This avoids potential issues with zeros being interpreted as "empty" and ensures the buffer is set. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| # Copyright 2025 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from typing import ClassVar, Dict, Tuple | ||
|
|
||
| import torch | ||
| from executorch.backends.arm._passes import PromoteBoolOperandsPass | ||
|
|
||
| from executorch.backends.arm.test import common | ||
| from executorch.backends.arm.test.tester.test_pipeline import PassPipeline | ||
| from executorch.backends.test.harness.stages import StageType | ||
| from executorch.exir.dialects._ops import ops as exir_ops | ||
|
|
||
| tensor_pair_t = Tuple[torch.Tensor, torch.Tensor] | ||
|
|
||
|
|
||
def _collect_cast_dtypes(pipeline: PassPipeline[tensor_pair_t]) -> list[torch.dtype]:
    """Return the ``dtype`` kwarg of each ``_to_dim_order_copy`` node.

    The graph inspected is the one held by the pipeline's ``RUN_PASSES``
    stage artifact. Nodes are reported in graph order; cast nodes without a
    ``dtype`` kwarg are ignored.
    """
    cast_target = exir_ops.edge.dim_order_ops._to_dim_order_copy.default
    run_passes_artifact = pipeline.tester.get_artifact(StageType.RUN_PASSES)
    graph = run_passes_artifact.exported_program().graph_module.graph
    return [
        node.kwargs["dtype"]
        for node in graph.nodes
        if node.op == "call_function"
        and node.target == cast_target
        and "dtype" in node.kwargs
    ]
|
|
||
|
|
||
class BoolBitwiseAndModule(torch.nn.Module):
    """Module whose forward is ``torch.bitwise_and`` of its two inputs.

    ``test_data`` holds a single case: a pair of 2x2 bool tensors.
    """

    test_data: ClassVar[Dict[str, tensor_pair_t]] = dict(
        bool_tensors=(
            torch.tensor([[True, False], [False, True]], dtype=torch.bool),
            torch.tensor([[False, True], [True, False]], dtype=torch.bool),
        ),
    )

    def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
        """Return the element-wise bitwise AND of ``lhs`` and ``rhs``."""
        return torch.bitwise_and(lhs, rhs)
|
|
||
|
|
||
class MixedMulModule(torch.nn.Module):
    """Module whose forward is ``torch.mul`` of its two inputs.

    ``test_data`` holds a single case: a bool tensor paired with an int32
    tensor, each with four elements.
    """

    test_data: ClassVar[Dict[str, tensor_pair_t]] = dict(
        mixed_tensors=(
            torch.tensor([True, False, True, False], dtype=torch.bool),
            torch.tensor([1, 2, 3, 4], dtype=torch.int32),
        ),
    )

    def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
        """Return the element-wise product of ``lhs`` and ``rhs``."""
        return torch.mul(lhs, rhs)
|
|
||
|
|
||
@common.parametrize("test_data", BoolBitwiseAndModule.test_data)
def test_promote_bool_operands_all_bool(test_data: tensor_pair_t) -> None:
    """Check the pass on a bitwise_and whose operands are both bool.

    Expects three ``_to_dim_order_copy`` nodes after the pass: two casts to
    int8 (one per operand) and one cast of the result back to bool.
    """
    bitwise_and_op = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor"
    cast_op = (
        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
    )

    pipeline = PassPipeline[tensor_pair_t](
        BoolBitwiseAndModule(),
        test_data,
        quantize=False,
        ops_before_pass={bitwise_and_op: 1},
        ops_after_pass={bitwise_and_op: 1, cast_op: 3},
        pass_list=[PromoteBoolOperandsPass],
    )
    pipeline.run()

    inserted_casts = _collect_cast_dtypes(pipeline)
    assert inserted_casts.count(torch.int8) == 2
    assert inserted_casts.count(torch.bool) == 1
|
|
||
|
|
||
@common.parametrize("test_data", MixedMulModule.test_data)
def test_promote_bool_operands_mixed_types(test_data: tensor_pair_t) -> None:
    """Check the pass on a mul with one bool and one int32 operand.

    Expects a single ``_to_dim_order_copy`` node after the pass, casting to
    int32.
    """
    mul_op = "executorch_exir_dialects_edge__ops_aten_mul_Tensor"
    cast_op = (
        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
    )

    pipeline = PassPipeline[tensor_pair_t](
        MixedMulModule(),
        test_data,
        quantize=False,
        ops_before_pass={mul_op: 1},
        ops_after_pass={mul_op: 1, cast_op: 1},
        pass_list=[PromoteBoolOperandsPass],
    )
    pipeline.run()

    inserted_casts = _collect_cast_dtypes(pipeline)
    assert inserted_casts.count(torch.int32) == 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment 'mul as int32' is unclear. Consider clarifying why mul operations specifically require int32 promotion (e.g., 'promote mul operations to int32 to match expected behavior' or 'mul requires int32 to handle result range').