Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa
from .broadcast_args_pass import BroadcastArgsPass # noqa
from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
from .cast_to_int32_pass import CastToInt32Pass # noqa
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
Expand Down Expand Up @@ -100,6 +99,7 @@
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
from .promote_bool_operands_pass import PromoteBoolOperandsPass # noqa
from .remove_getitem_pass import RemoveGetItemPass # noqa
from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa
from .remove_noop_pass import RemoveNoopPass # noqa
Expand Down
7 changes: 3 additions & 4 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
AnnotateDecomposedMatmulPass,
AnnotateOutputDimOrderPass,
BroadcastArgsPass,
CastBoolToInt8Pass,
CastInt64BuffersToInt32Pass,
CastToInt32Pass,
ComputeConstantOpsAOTPass,
Expand Down Expand Up @@ -92,6 +91,7 @@
InsertTableOpsPass,
MatchArgDtypePass,
MatchArgRanksPass,
PromoteBoolOperandsPass,
QuantizeClampArgumentsPass,
RemoveGetItemPass,
RemoveGraphAssertsPass,
Expand Down Expand Up @@ -122,7 +122,6 @@


class ArmPassManager(PassManager):

def __init__(self, tosa_spec: TosaSpecification) -> None:
self.tosa_spec = tosa_spec
super().__init__()
Expand Down Expand Up @@ -217,7 +216,7 @@ def _tosa_pipeline(
DecomposeEluPass(),
DecomposeExpm1Pass(),
DecomposeIntPowPass(),
CastBoolToInt8Pass(),
PromoteBoolOperandsPass(),
DecomposeSinhPass(),
DecomposeSignPass(),
DecomposeFloorDividePass(),
Expand Down Expand Up @@ -329,7 +328,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
DecomposeScaledDotProductAttentionPass(),
DecomposeRoundPass(),
DecomposeLogitPass(),
CastBoolToInt8Pass(),
PromoteBoolOperandsPass(),
DecomposeSignPass(),
DecomposeAddmmPass(),
DecomposeRemainderPass(),
Expand Down
63 changes: 0 additions & 63 deletions backends/arm/_passes/cast_bool_to_int8_pass.py

This file was deleted.

88 changes: 88 additions & 0 deletions backends/arm/_passes/promote_bool_operands_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool inputs.
# When a targeted op receives boolean tensors, we promote them to an integer type before
# invocation and cast the result back to the expected dtype afterwards.

from typing import Set, Type

import torch

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class PromoteBoolOperandsPass(ArmPass):
"""Promote boolean operands to the appropriate integer dtype for unsupported ops."""

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops = {
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.mul.Tensor,
}

def call_operator(self, op, args, kwargs, meta):
if op not in self.targeted_ops:
return super().call_operator(op, args, kwargs, meta)

original_dtypes = [arg.data.dtype for arg in args]
if torch.bool not in original_dtypes:
return super().call_operator(op, args, kwargs, meta)

# select the first non-bool dtype, or None if all bool
promoted_dtype = next((dt for dt in original_dtypes if dt != torch.bool), None)

# if we don't have a dtype specified by the op, promote to default choice for the op
if promoted_dtype is None:
if op == exir_ops.edge.aten.mul.Tensor:
# mul as int32
promoted_dtype = torch.int32
else:
# bitwise ops can be int8
promoted_dtype = torch.int8
Comment on lines +44 to +49
Copy link

Copilot AI Nov 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment 'mul as int32' is unclear. Consider clarifying why mul operations specifically require int32 promotion (e.g., 'promote mul operations to int32 to match expected behavior' or 'mul requires int32 to handle result range').

Copilot uses AI. Check for mistakes.

target_dtypes = []
for dt in original_dtypes:
if dt == torch.bool:
target_dtypes.append(promoted_dtype)
else:
target_dtypes.append(dt)

new_args = []
for arg, original_dtype, target_dtype in zip(
args, original_dtypes, target_dtypes
):
if original_dtype == target_dtype:
new_args.append(arg)
else:
new_args.append(
super().call_operator(
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
(arg,),
{"dtype": target_dtype},
meta,
)
)

output = super().call_operator(
op,
tuple(new_args),
kwargs,
meta,
)

if all(dtype == torch.bool for dtype in original_dtypes):
output = super().call_operator(
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
(output,),
{"dtype": torch.bool},
meta,
)
return output
32 changes: 23 additions & 9 deletions backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from executorch.exir import ExportedProgram
from executorch.exir.pass_base import ExportPass, PassResult
from torch._export.utils import is_buffer, is_param
from torch.export.graph_signature import InputKind


class UnsqueezeScalarPlaceholdersPass(ArmPass):
Expand Down Expand Up @@ -42,17 +43,30 @@ def call(self, graph_module: torch.fx.GraphModule):
else:
continue

tensor = self.exported_program.state_dict[name]
tensor = self.exported_program.state_dict.get(name)

# If we have a persistent=False buffer with no entry in state_dict
spec = next(
s
for s in self.exported_program.graph_signature.input_specs
if getattr(s.arg, "name", None) == node.name
)
Comment on lines +50 to +53
Copy link

Copilot AI Nov 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The next() call without a default argument will raise StopIteration if no matching spec is found. Consider adding a default value or explicit error handling to provide a clearer error message if the spec is not found.

Suggested change
s
for s in self.exported_program.graph_signature.input_specs
if getattr(s.arg, "name", None) == node.name
)
(s
for s in self.exported_program.graph_signature.input_specs
if getattr(s.arg, "name", None) == node.name),
None
)
if spec is None:
raise ValueError(
f"No matching input spec found for placeholder node '{node.name}'."
)

Copilot uses AI. Check for mistakes.
is_non_persistent_buffer = (
spec.kind is InputKind.BUFFER and spec.persistent is False
)
if tensor is None and is_non_persistent_buffer:
fake = node.meta["val"]
Copy link

Copilot AI Nov 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a comment explaining why torch.ones_like is used here for non-persistent buffers. The choice of ones (versus zeros or random values) may not be immediately clear to future maintainers.

Suggested change
fake = node.meta["val"]
fake = node.meta["val"]
# For non-persistent buffers, we initialize with ones to provide a consistent, non-zero default.
# This avoids potential issues with zeros being interpreted as "empty" and ensures the buffer is set.

Copilot uses AI. Check for mistakes.
tensor = torch.ones_like(fake)

# If we have a scalar, unsqueeze it
if tensor.dim() == 0:
self.exported_program.state_dict[name] = tensor.unsqueeze(0)
node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
tensor.unsqueeze(0), static_shapes=True
)
else:
node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
tensor, static_shapes=True
)
tensor = tensor.unsqueeze(0)

# update or create entry in state_dict, recreate fake
self.exported_program.state_dict[name] = tensor
node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
tensor, static_shapes=True
)

graph_module.recompile()
graph_module = super().call(graph_module).graph_module
Expand Down
103 changes: 103 additions & 0 deletions backends/arm/test/passes/test_promote_bool_operands_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import ClassVar, Dict, Tuple

import torch
from executorch.backends.arm._passes import PromoteBoolOperandsPass

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
from executorch.backends.test.harness.stages import StageType
from executorch.exir.dialects._ops import ops as exir_ops

tensor_pair_t = Tuple[torch.Tensor, torch.Tensor]


def _collect_cast_dtypes(pipeline: PassPipeline[tensor_pair_t]) -> list[torch.dtype]:
exported_program = pipeline.tester.get_artifact(
StageType.RUN_PASSES
).exported_program()
graph_module = exported_program.graph_module
cast_dtypes: list[torch.dtype] = []
for node in graph_module.graph.nodes:
if (
node.op == "call_function"
and node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default
and "dtype" in node.kwargs
):
cast_dtypes.append(node.kwargs["dtype"])
return cast_dtypes


class BoolBitwiseAndModule(torch.nn.Module):
test_data: ClassVar[Dict[str, tensor_pair_t]] = {
"bool_tensors": (
torch.tensor([[True, False], [False, True]], dtype=torch.bool),
torch.tensor([[False, True], [True, False]], dtype=torch.bool),
)
}

def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
return torch.bitwise_and(lhs, rhs)


class MixedMulModule(torch.nn.Module):
test_data: ClassVar[Dict[str, tensor_pair_t]] = {
"mixed_tensors": (
torch.tensor([True, False, True, False], dtype=torch.bool),
torch.tensor([1, 2, 3, 4], dtype=torch.int32),
)
}

def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
return torch.mul(lhs, rhs)


@common.parametrize("test_data", BoolBitwiseAndModule.test_data)
def test_promote_bool_operands_all_bool(test_data: tensor_pair_t) -> None:
module = BoolBitwiseAndModule()
ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor": 1,
}
ops_after_pass = {
"executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor": 1,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3,
}
pipeline = PassPipeline[tensor_pair_t](
module,
test_data,
quantize=False,
ops_before_pass=ops_before_pass,
ops_after_pass=ops_after_pass,
pass_list=[PromoteBoolOperandsPass],
)
pipeline.run()
cast_dtypes = _collect_cast_dtypes(pipeline)
assert cast_dtypes.count(torch.int8) == 2
assert cast_dtypes.count(torch.bool) == 1


@common.parametrize("test_data", MixedMulModule.test_data)
def test_promote_bool_operands_mixed_types(test_data: tensor_pair_t) -> None:
module = MixedMulModule()
ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
}
ops_after_pass = {
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
}
pipeline = PassPipeline[tensor_pair_t](
module,
test_data,
quantize=False,
ops_before_pass=ops_before_pass,
ops_after_pass=ops_after_pass,
pass_list=[PromoteBoolOperandsPass],
)
pipeline.run()
cast_dtypes = _collect_cast_dtypes(pipeline)
assert cast_dtypes.count(torch.int32) == 1