Commit d59fddc

Arm backend: Support ScalarType::Bool in EthosUBackend (#11850)

### Summary
Add better support for booleans in the Arm backend.

### Test plan
There are a number of unit tests covering this in the few ops that have been enabled.

Signed-off-by: Zingo Andersen <[email protected]>
1 parent 4cb71a0 commit d59fddc
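For context, here is a minimal sketch (not part of the commit) of the kind of boolean graph this change targets: bool tensors feeding a bitwise op, which the new CastBoolToInt8Pass lowers via int8 and which the EthosUBackend now accepts as 1-byte bool I/O.

import torch

class BoolAnd(torch.nn.Module):
    # Toy module for illustration only; bool in, bool out.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.bitwise_and(x, y)

x = torch.tensor([True, False, True, False])
y = torch.tensor([True, True, False, False])
print(BoolAnd()(x, y))  # tensor([ True, False, False, False])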

File tree: 11 files changed, +402 −89 lines


backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder  # noqa
 from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass  # noqa
 from .broadcast_args_pass import BroadcastArgsPass  # noqa
+from .cast_bool_to_int8_pass import CastBoolToInt8Pass  # noqa
 from .cast_int64_pass import CastInt64BuffersToInt32Pass  # noqa
 from .cast_to_int32_pass import CastToInt32Pass  # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -11,6 +11,7 @@
     AnnotateChannelsLastDimOrder,
     AnnotateDecomposedMatmulPass,
     BroadcastArgsPass,
+    CastBoolToInt8Pass,
     CastInt64BuffersToInt32Pass,
     CastToInt32Pass,
     ComputeConstantOpsAOT,
@@ -108,6 +109,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule
         if self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())

+        self.add_pass(CastBoolToInt8Pass())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
@@ -148,6 +150,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule
         self.add_pass(DecomposeRoundPass())
         self.add_pass(DecomposeSqrtPass())
         self.add_pass(ConvertIntPowToMuls())
+        self.add_pass(CastBoolToInt8Pass())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(FuseQuantizedActivationPass())
@@ -230,6 +233,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoundPass())
+        self.add_pass(CastBoolToInt8Pass())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeGroupNormPass())
backends/arm/_passes/cast_bool_to_int8_pass.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool as input.
+# If the input/output is bool, let's add a cast/conversion pass before/after to/from int8.
+
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+class CastBoolToInt8Pass(ExportPass):
+    """Casts the input to int8 if it is not already and casts back the output to the original input dtype."""
+
+    targeted_ops = {
+        exir_ops.edge.aten.bitwise_and.Tensor,
+        exir_ops.edge.aten.bitwise_or.Tensor,
+        exir_ops.edge.aten.bitwise_xor.Tensor,
+    }
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in self.targeted_ops:
+            return super().call_operator(op, args, kwargs, meta)
+
+        new_args: list = []
+        did_cast = False
+        for arg in args:
+            if arg.data.dtype == torch.bool:
+                new_args.append(
+                    super().call_operator(
+                        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                        (arg,),
+                        {"dtype": torch.int8},
+                        meta,
+                    )
+                )
+                did_cast = True
+            else:
+                new_args.append(arg)
+
+        output = super().call_operator(
+            op,
+            tuple(new_args),
+            {},
+            meta,
+        )
+
+        if did_cast:
+            output = super().call_operator(
+                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                (output,),
+                {"dtype": args[0].data.dtype},
+                meta,
+            )
+        return output
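In eager terms, the rewrite this pass performs amounts to the following sketch (plain torch casts stand in for the _to_dim_order_copy nodes the pass actually emits):

import torch

def bitwise_and_via_int8(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Cast bool operands to int8, run the bitwise op, then cast the result
    # back to the original input dtype, mirroring the cast/op/cast chain.
    out = torch.bitwise_and(x.to(torch.int8), y.to(torch.int8))
    return out.to(x.dtype)

x = torch.tensor([True, False, True])
y = torch.tensor([True, True, False])
assert torch.equal(bitwise_and_via_int8(x, y), torch.bitwise_and(x, y))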

backends/arm/operators/ops_binary.py

Lines changed: 49 additions & 0 deletions
@@ -17,6 +17,7 @@
 from executorch.backends.arm.operators.operator_validation_utils import (
     validate_num_inputs,
     validate_same_dtype,
+    validate_valid_dtype,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg

@@ -40,6 +41,30 @@ def define_node(
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [*inputs, output], ts)

+        if self.target in [
+            "aten.bitwise_and.Tensor",
+            "aten.bitwise_xor.Tensor",
+            "aten.bitwise_or.Tensor",
+            "aten.bitwise_left_shift.Tensor",
+        ]:
+            validate_valid_dtype(
+                self.target,
+                [*inputs, output],
+                [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
+                output.tosa_spec,
+            )
+        if self.target in [
+            "aten.logical_and.default",
+            "aten.logical_xor.default",
+            "aten.logical_or.default",
+        ]:
+            validate_valid_dtype(
+                self.target,
+                [*inputs, output],
+                [ts.DType.BOOL],
+                output.tosa_spec,
+            )
+
         tosa_graph.addOperator(
             tosa_op, [inputs[0].name, inputs[1].name], [output.name]
         )
@@ -66,6 +91,30 @@ def define_node(
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [*inputs, output], ts)

+        if self.target in [
+            "aten.bitwise_and.Tensor",
+            "aten.bitwise_xor.Tensor",
+            "aten.bitwise_or.Tensor",
+            "aten.bitwise_left_shift.Tensor",
+        ]:
+            validate_valid_dtype(
+                self.target,
+                [*inputs, output],
+                [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
+                output.tosa_spec,
+            )
+        if self.target in [
+            "aten.logical_and.default",
+            "aten.logical_xor.default",
+            "aten.logical_or.default",
+        ]:
+            validate_valid_dtype(
+                self.target,
+                [*inputs, output],
+                [ts.DType.BOOL],
+                output.tosa_spec,
+            )
+
         tosa_graph.addOperator(
             tosa_op, [inputs[0].name, inputs[1].name], [output.name]
         )
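For intuition, the dtype check called above boils down to something like this simplified, hypothetical sketch (the real validate_valid_dtype lives in operator_validation_utils and also takes the TOSA spec into account):

def validate_valid_dtype_sketch(target, tensors, valid_dtypes):
    # Hypothetical stand-in for the real helper: reject any tensor whose
    # dtype is outside the set the TOSA operator supports.
    for t in tensors:
        if t.dtype not in valid_dtypes:
            raise ValueError(
                f"{target}: dtype {t.dtype} is not in supported set {valid_dtypes}"
            )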

backends/arm/runtime/EthosUBackend.cpp

Lines changed: 24 additions & 20 deletions
@@ -234,12 +234,17 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
       supported |=
           (tensor_in.scalar_type() == ScalarType::Short and
            handles.inputs->io[i].elem_size == 2);
+      // bool (IOQDQ pass prepared networks)
+      supported |=
+          (tensor_in.scalar_type() == ScalarType::Bool and
+           handles.inputs->io[i].elem_size == 1);
       if (!supported) {
         ET_LOG(
             Error,
-            "Input %d expected Integer (4 byte) or Char (1 byte) integer inputs, got ScalarType id %s",
+            "Input %d expected Integer (4 byte), Char (1 byte) or Bool (1 byte) integer inputs, got ScalarType id %s size %d",
             i,
-            executorch::runtime::toString(tensor_in.scalar_type()));
+            executorch::runtime::toString(tensor_in.scalar_type()),
+            handles.inputs->io[i].elem_size);
         return Error::InvalidProgram;
       }
       supported = executorch::runtime::is_contiguous_dim_order(
@@ -257,15 +262,17 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
       bool permuted_input_shape;
       ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
           i, tensor_in, &handles.inputs->io[i], &permuted_input_shape));
-      bool both_char = tensor_in.scalar_type() == ScalarType::Char and
-          handles.inputs->io[i].elem_size == 1;
-      bool both_int = tensor_in.scalar_type() == ScalarType::Int and
+      bool both_int = tensor_in.scalar_type() == ScalarType::Int &&
           handles.inputs->io[i].elem_size == 4;
-      bool both_short = tensor_in.scalar_type() == ScalarType::Short and
+      bool both_char = tensor_in.scalar_type() == ScalarType::Char &&
+          handles.inputs->io[i].elem_size == 1;
+      bool both_short = tensor_in.scalar_type() == ScalarType::Short &&
           handles.inputs->io[i].elem_size == 2;
+      bool both_bool = tensor_in.scalar_type() == ScalarType::Bool &&
+          (handles.inputs->io[i].elem_size == 1);

       // Select a compatible copy routine
-      if (both_char && permuted_input_shape) {
+      if ((both_char || both_bool) && permuted_input_shape) {
         EXECUTORCH_PROF_SCOPE(
             event_tracer,
             "+EthosUBackend::execute()handles.input.permute_CHW_to_HWC()");
@@ -276,7 +283,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
             tensor_in.size(1),
             tensor_in.size(2),
             tensor_in.size(3));
-      } else if (both_char || both_int || both_short) {
+      } else if (both_char || both_int || both_short || both_bool) {
         EXECUTORCH_PROF_SCOPE(
             event_tracer, "+EthosUBackend::execute()handles.input.memcpy()");
         // Sizes match and elt size matches so memcpy
@@ -363,7 +370,9 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
       bool permuted_output_shape;
       ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
           i, tensor_out, &handles.outputs->io[i], &permuted_output_shape));
-      if (tensor_out.scalar_type() == ScalarType::Char &&
+
+      if ((tensor_out.scalar_type() == ScalarType::Char ||
+           tensor_out.scalar_type() == ScalarType::Bool) &&
           permuted_output_shape) {
         EXECUTORCH_PROF_SCOPE(
             event_tracer,
@@ -379,17 +388,12 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
             tensor_out.size(3));
       } else {
         EXECUTORCH_PROF_SCOPE(
-            event_tracer, "+EthosUBackend::execute()handles.output.move()");
-        for (int j = 0; j < tensor_out.numel(); j++) {
-          if (tensor_out.scalar_type() == ScalarType::Char) {
-            const char* output_address = static_cast<const char*>(output_addr);
-            tensor_out.mutable_data_ptr<char>()[j] = output_address[j];
-          } else {
-            const int* output_address =
-                reinterpret_cast<const int*>(output_addr);
-            tensor_out.mutable_data_ptr<int>()[j] = output_address[j];
-          }
-        }
+            event_tracer, "+EthosUBackend::execute()handles.output.memcpy()");
+
+        memcpy(
+            tensor_out.mutable_data_ptr<char>(),
+            static_cast<const char*>(output_addr),
+            tensor_out.nbytes());
       }
     }
     if (tensor_dim != io_dim) {
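The output path drops the per-element copy loop in favor of a single memcpy of tensor_out.nbytes(); since source and destination hold the same dtype and element size, a raw byte copy is equivalent whether the elements are char, int, short, or bool. A rough Python analogue of that equivalence:

import torch

src = torch.tensor([1, 0, 1, 1], dtype=torch.int8)

# Per-element copy, as in the removed loop.
out_loop = torch.empty_like(src)
for j in range(src.numel()):
    out_loop[j] = src[j]

# One bulk byte copy, as in the new memcpy path.
out_bulk = torch.empty_like(src)
out_bulk.view(torch.uint8).copy_(src.view(torch.uint8))

assert torch.equal(out_loop, out_bulk)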

backends/arm/test/ops/test_any.py

Lines changed: 21 additions & 4 deletions
@@ -6,7 +6,6 @@

 from typing import List, Tuple

-import pytest
 import torch
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
@@ -125,14 +124,30 @@ def forward(self, x: torch.Tensor):
 @common.parametrize("test_data", test_data)
 def test_any_tosa_MI(test_data: input_t1):
     op, test_input = test_data()
-    pipeline = TosaPipelineMI[input_t1](op, test_input(), op.aten_op, op.exir_op)
+    pipeline = TosaPipelineMI[input_t1](
+        op,
+        test_input(),
+        op.aten_op,
+        op.exir_op,
+        atol=0,
+        rtol=0,
+        qtol=0,
+    )
     pipeline.run()


 @common.parametrize("test_data", test_data)
 def test_any_tosa_BI(test_data: input_t1):
     op, test_input = test_data()
-    pipeline = TosaPipelineBI[input_t1](op, test_input(), op.aten_op, op.exir_op)
+    pipeline = TosaPipelineBI[input_t1](
+        op,
+        test_input(),
+        op.aten_op,
+        op.exir_op,
+        atol=0,
+        rtol=0,
+        qtol=0,
+    )
     pipeline.pop_stage("quantize")
     pipeline.pop_stage("check.quant_nodes")
     pipeline.run()
@@ -153,7 +168,6 @@ def test_any_u55_BI(test_data: input_t1):


 @common.parametrize("test_data", test_data)
-@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.")
 @common.XfailIfNoCorstone320
 def test_any_u85_BI(test_data: input_t1):
     op, test_input = test_data()
@@ -163,6 +177,9 @@ def test_any_u85_BI(test_data: input_t1):
         op.aten_op,
         op.exir_op,
         run_on_fvp=True,
+        atol=0,
+        rtol=0,
+        qtol=0,
    )
     pipeline.pop_stage("quantize")
     pipeline.pop_stage("check.quant_nodes")
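The zero tolerances added above suit boolean outputs: any mismatch is a genuine failure, so comparisons should be bit-exact. In plain torch.testing terms the intent corresponds roughly to:

import torch

expected = torch.tensor([True, False, True, False])
actual = torch.tensor([True, False, True, False])
# atol=0 / rtol=0 only pass on exact equality, which is what bool
# outputs require.
torch.testing.assert_close(actual, expected, atol=0, rtol=0)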
