diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 0b51c28cde8..595f8c9a5ed 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -9,7 +9,6 @@ from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa from .broadcast_args_pass import BroadcastArgsPass # noqa -from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa from .cast_to_int32_pass import CastToInt32Pass # noqa from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa @@ -100,6 +99,7 @@ from .match_arg_dtype_pass import MatchArgDtypePass # noqa from .match_arg_ranks_pass import MatchArgRanksPass # noqa from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa +from .promote_bool_operands_pass import PromoteBoolOperandsPass # noqa from .remove_getitem_pass import RemoveGetItemPass # noqa from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa from .remove_noop_pass import RemoveNoopPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 75fc529a7e1..42a223d72fd 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -14,7 +14,6 @@ AnnotateDecomposedMatmulPass, AnnotateOutputDimOrderPass, BroadcastArgsPass, - CastBoolToInt8Pass, CastInt64BuffersToInt32Pass, CastToInt32Pass, ComputeConstantOpsAOTPass, @@ -92,6 +91,7 @@ InsertTableOpsPass, MatchArgDtypePass, MatchArgRanksPass, + PromoteBoolOperandsPass, QuantizeClampArgumentsPass, RemoveGetItemPass, RemoveGraphAssertsPass, @@ -122,7 +122,6 @@ class ArmPassManager(PassManager): - def __init__(self, tosa_spec: TosaSpecification) -> None: self.tosa_spec = tosa_spec super().__init__() @@ -217,7 +216,7 @@ def _tosa_pipeline( DecomposeEluPass(), DecomposeExpm1Pass(), DecomposeIntPowPass(), - CastBoolToInt8Pass(), + PromoteBoolOperandsPass(), DecomposeSinhPass(), DecomposeSignPass(), DecomposeFloorDividePass(), @@ -329,7 +328,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): DecomposeScaledDotProductAttentionPass(), DecomposeRoundPass(), DecomposeLogitPass(), - CastBoolToInt8Pass(), + PromoteBoolOperandsPass(), DecomposeSignPass(), DecomposeAddmmPass(), DecomposeRemainderPass(), diff --git a/backends/arm/_passes/cast_bool_to_int8_pass.py b/backends/arm/_passes/cast_bool_to_int8_pass.py deleted file mode 100644 index 0987476a2ec..00000000000 --- a/backends/arm/_passes/cast_bool_to_int8_pass.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool as input -# If input/output is bool lest add a cast/conversion pass before/after to/from int8. - -from typing import Set, Type - -import torch - -from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass - - -class CastBoolToInt8Pass(ArmPass): - """Casts the input to int8 if it is not already and casts back the output to the original input dtype.""" - - _passes_required_after: Set[Type[ExportPass]] = set() - - targeted_ops = { - exir_ops.edge.aten.bitwise_and.Tensor, - exir_ops.edge.aten.bitwise_or.Tensor, - exir_ops.edge.aten.bitwise_xor.Tensor, - } - - def call_operator(self, op, args, kwargs, meta): - if op not in self.targeted_ops: - return super().call_operator(op, args, kwargs, meta) - - new_args: list = [] - did_cast = False - for arg in args: - if arg.data.dtype == torch.bool: - new_args.append( - super().call_operator( - exir_ops.edge.dim_order_ops._to_dim_order_copy.default, - (arg,), - {"dtype": torch.int8}, - meta, - ) - ) - did_cast = True - else: - new_args.append(arg) - - output = super().call_operator( - op, - tuple(new_args), - {}, - meta, - ) - - if did_cast: - output = super().call_operator( - exir_ops.edge.dim_order_ops._to_dim_order_copy.default, - (output,), - {"dtype": args[0].data.dtype}, - meta, - ) - return output diff --git a/backends/arm/_passes/promote_bool_operands_pass.py b/backends/arm/_passes/promote_bool_operands_pass.py new file mode 100644 index 00000000000..8c45a808cb5 --- /dev/null +++ b/backends/arm/_passes/promote_bool_operands_pass.py @@ -0,0 +1,88 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool inputs. +# When a targeted op receives boolean tensors, we promote them to an integer type before +# invocation and cast the result back to the expected dtype afterwards. + +from typing import Set, Type + +import torch + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class PromoteBoolOperandsPass(ArmPass): + """Promote boolean operands to the appropriate integer dtype for unsupported ops.""" + + _passes_required_after: Set[Type[ExportPass]] = set() + + targeted_ops = { + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_xor.Tensor, + exir_ops.edge.aten.mul.Tensor, + } + + def call_operator(self, op, args, kwargs, meta): + if op not in self.targeted_ops: + return super().call_operator(op, args, kwargs, meta) + + original_dtypes = [arg.data.dtype for arg in args] + if torch.bool not in original_dtypes: + return super().call_operator(op, args, kwargs, meta) + + # select the first non-bool dtype, or None if all bool + promoted_dtype = next((dt for dt in original_dtypes if dt != torch.bool), None) + + # if we don't have a dtype specified by the op, promote to default choice for the op + if promoted_dtype is None: + if op == exir_ops.edge.aten.mul.Tensor: + # mul as int32 + promoted_dtype = torch.int32 + else: + # bitwise ops can be int8 + promoted_dtype = torch.int8 + + target_dtypes = [] + for dt in original_dtypes: + if dt == torch.bool: + target_dtypes.append(promoted_dtype) + else: + target_dtypes.append(dt) + + new_args = [] + for arg, original_dtype, target_dtype in zip( + args, original_dtypes, target_dtypes + ): + if original_dtype == target_dtype: + new_args.append(arg) + else: + new_args.append( + super().call_operator( + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + (arg,), + {"dtype": target_dtype}, + meta, + ) + ) + + output = super().call_operator( + op, + tuple(new_args), + kwargs, + meta, + ) + + if all(dtype == torch.bool for dtype in original_dtypes): + output = super().call_operator( + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + (output,), + {"dtype": torch.bool}, + meta, + ) + return output diff --git a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py index bd093d6774e..15a105a1fd8 100644 --- a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py +++ b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py @@ -11,6 +11,7 @@ from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass, PassResult from torch._export.utils import is_buffer, is_param +from torch.export.graph_signature import InputKind class UnsqueezeScalarPlaceholdersPass(ArmPass): @@ -42,17 +43,30 @@ def call(self, graph_module: torch.fx.GraphModule): else: continue - tensor = self.exported_program.state_dict[name] + tensor = self.exported_program.state_dict.get(name) + # If we have a persistent=False buffer with no entry in state_dict + spec = next( + s + for s in self.exported_program.graph_signature.input_specs + if getattr(s.arg, "name", None) == node.name + ) + is_non_persistent_buffer = ( + spec.kind is InputKind.BUFFER and spec.persistent is False + ) + if tensor is None and is_non_persistent_buffer: + fake = node.meta["val"] + tensor = torch.ones_like(fake) + + # If we have a scalar, unsqueeze it if tensor.dim() == 0: - self.exported_program.state_dict[name] = tensor.unsqueeze(0) - node.meta["val"] = node.meta["val"].fake_mode.from_tensor( - tensor.unsqueeze(0), static_shapes=True - ) - else: - node.meta["val"] = node.meta["val"].fake_mode.from_tensor( - tensor, static_shapes=True - ) + tensor = tensor.unsqueeze(0) + + # update or create entry in state_dict, recreate fake + self.exported_program.state_dict[name] = tensor + node.meta["val"] = node.meta["val"].fake_mode.from_tensor( + tensor, static_shapes=True + ) graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/test/passes/test_promote_bool_operands_pass.py b/backends/arm/test/passes/test_promote_bool_operands_pass.py new file mode 100644 index 00000000000..48c9778a75c --- /dev/null +++ b/backends/arm/test/passes/test_promote_bool_operands_pass.py @@ -0,0 +1,103 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import ClassVar, Dict, Tuple + +import torch +from executorch.backends.arm._passes import PromoteBoolOperandsPass + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.test.harness.stages import StageType +from executorch.exir.dialects._ops import ops as exir_ops + +tensor_pair_t = Tuple[torch.Tensor, torch.Tensor] + + +def _collect_cast_dtypes(pipeline: PassPipeline[tensor_pair_t]) -> list[torch.dtype]: + exported_program = pipeline.tester.get_artifact( + StageType.RUN_PASSES + ).exported_program() + graph_module = exported_program.graph_module + cast_dtypes: list[torch.dtype] = [] + for node in graph_module.graph.nodes: + if ( + node.op == "call_function" + and node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default + and "dtype" in node.kwargs + ): + cast_dtypes.append(node.kwargs["dtype"]) + return cast_dtypes + + +class BoolBitwiseAndModule(torch.nn.Module): + test_data: ClassVar[Dict[str, tensor_pair_t]] = { + "bool_tensors": ( + torch.tensor([[True, False], [False, True]], dtype=torch.bool), + torch.tensor([[False, True], [True, False]], dtype=torch.bool), + ) + } + + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor: + return torch.bitwise_and(lhs, rhs) + + +class MixedMulModule(torch.nn.Module): + test_data: ClassVar[Dict[str, tensor_pair_t]] = { + "mixed_tensors": ( + torch.tensor([True, False, True, False], dtype=torch.bool), + torch.tensor([1, 2, 3, 4], dtype=torch.int32), + ) + } + + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor: + return torch.mul(lhs, rhs) + + +@common.parametrize("test_data", BoolBitwiseAndModule.test_data) +def test_promote_bool_operands_all_bool(test_data: tensor_pair_t) -> None: + module = BoolBitwiseAndModule() + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor": 1, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3, + } + pipeline = PassPipeline[tensor_pair_t]( + module, + test_data, + quantize=False, + ops_before_pass=ops_before_pass, + ops_after_pass=ops_after_pass, + pass_list=[PromoteBoolOperandsPass], + ) + pipeline.run() + cast_dtypes = _collect_cast_dtypes(pipeline) + assert cast_dtypes.count(torch.int8) == 2 + assert cast_dtypes.count(torch.bool) == 1 + + +@common.parametrize("test_data", MixedMulModule.test_data) +def test_promote_bool_operands_mixed_types(test_data: tensor_pair_t) -> None: + module = MixedMulModule() + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1, + } + pipeline = PassPipeline[tensor_pair_t]( + module, + test_data, + quantize=False, + ops_before_pass=ops_before_pass, + ops_after_pass=ops_after_pass, + pass_list=[PromoteBoolOperandsPass], + ) + pipeline.run() + cast_dtypes = _collect_cast_dtypes(pipeline) + assert cast_dtypes.count(torch.int32) == 1