
Commit e13b086

Arm backend: Merge decompose/convert meandim pass (#10844)
This change has multiple benefits:
- Cleaner arm_pass_manager
- Uses the more efficient avgpool2d decomposition in more cases
- Fixes a bug when decomposing to avgpool for rank != 4

Note that symmetric_io_quantization is required for the unit tests to result in only an avgpool op, because of the way avgpool is annotated.

Signed-off-by: Adrian Lundell <[email protected]>
1 parent d0360b7 commit e13b086
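
As background for the decomposition this commit switches to, the snippet below is a minimal PyTorch sketch (not code from this commit; the tensor shape is made up for illustration) checking that a mean over the spatial dims matches a full-kernel avg_pool2d, and that a mean over a leading dim matches sum followed by mul(1/N), which is the split the merged DecomposeMeanDimPass performs.

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)  # (n, c, h, w); example shape, not from the commit

# mean over (h, w) == avg_pool2d with a kernel covering the full spatial extent
mean_hw = x.mean(dim=(-2, -1), keepdim=True)
pooled = F.avg_pool2d(x, kernel_size=(8, 8), stride=(1, 1))
assert torch.allclose(mean_hw, pooled, atol=1e-6)

# mean over c == sum over c followed by mul(1/C)
mean_c = x.mean(dim=1, keepdim=True)
summed = x.sum(dim=1, keepdim=True) * (1.0 / x.shape[1])
assert torch.allclose(mean_c, summed, atol=1e-6)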

File tree

9 files changed: 272 additions, 216 deletions


backends/arm/_passes/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -47,7 +47,6 @@
 from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa
 from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass  # noqa
-from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass  # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
 from .remove_clone_pass import RemoveClonePass  # noqa
 from .replace_scalar_with_tensor_pass import (  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 1 addition & 3 deletions
@@ -17,7 +17,6 @@
     ConvertAnyDefaultDimDimsPass,
     ConvertExpandCopyToRepeatPass,
     ConvertFullLikeToFullPass,
-    ConvertMeanDimToAveragePoolPass,
     ConvertMinMaxPass,
     ConvertMmToBmmPass,
     ConvertSplitToSlicePass,
@@ -87,7 +86,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeLinearPass())
-        self.add_pass(ConvertMeanDimToAveragePoolPass())
+        self.add_pass(DecomposeMeanDimPass())
         self.add_pass(ConvertFullLikeToFullPass())
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
@@ -140,7 +139,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
         self.add_pass(DecomposeNotEqualPass())
-        self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeSoftmaxPass())
         self.add_pass(DecomposeGeluPass())
backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 89 additions & 21 deletions
@@ -1,10 +1,9 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# pyre-unsafe
+from math import prod

 import torch
 from executorch.backends.arm._passes import ArmPass
@@ -28,42 +27,111 @@ def get_meandim_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get meandim decomposition for op {op}")


+def get_avgpool(op):
+    if op == exir_ops.edge.aten.mean.dim:
+        return exir_ops.edge.aten.avg_pool2d.default
+    if op == torch.ops.aten.mean.dim:
+        return torch.ops.aten.avg_pool2d.default
+    raise RuntimeError(f"Can't get meandim decomposition for op {op}")
+
+
+def get_view(op):
+    if op == exir_ops.edge.aten.mean.dim:
+        return exir_ops.edge.aten.view_copy.default
+    if op == torch.ops.aten.mean.dim:
+        return torch.ops.aten.view_copy.default
+    raise RuntimeError(f"Can't get meandim decomposition for op {op}")
+
+
 class DecomposeMeanDimPass(ArmPass):
     """
-    This pass decomposes meandim into a sum and mul node.
+    Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
+        h,w -> avg_pool
+        n,c -> sum + mul(1/N)
+    For rank < 4, the input is first reshaped to 4D by padding with dim=1 from the left.

     Example:
-        y = mean_dim(x, dim, keepdim)
+        x = mean_dim(x, (0,2), keepdim=False)  # x = (c,h,w)
     Becomes:
-        sum = sum.dim_IntList(x, dim, keepdim)
-        y = mul(sum, 1/N)
+        x = view_copy.default(x, new_shape=(1,c,h,w))  # Reshape to work with avg_pool
+        x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1))  # Reduce w with avg_pool
+        x = sum.dim_IntList(x, dim=1, keepdims=True)  # Reduce c with sum
+        x = mul.Tensor(x, 1/c)  # Divide by number of channels to get mean
+        x = view_copy.default(x, new_shape=(h))  # Squeeze dims since keepdims = False
     """

     def call_operator(self, op, args, kwargs, meta):
         if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
             return super().call_operator(op, args, kwargs, meta)

         x = get_node_arg(args, 0)
-        dim = get_node_arg(args, 1)
-        keepdim = get_node_arg(args, 2, False)
-
-        # if dim == [-1, -2], mean.dim can be
-        # decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
-        if dim == [-1, -2]:
-            # Simply return the mean.dim operator for future decomposition.
-            return super().call_operator(op, args, kwargs, meta)
+        input_shape = x.data.size()
+        output_shape = meta["val"].size()
+        dims_to_reduce = get_node_arg(args, 1)
+        dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]

-        shape = meta["val"].size()
         dtype = meta["val"].dtype
-        input_shape = x.data.size()
-        N = 1
-        for d in dim:
-            N *= input_shape[d]
+        view_op = get_view(op)

+        if len(input_shape) > 4:
+            raise NotImplementedError(
+                f"{op} with rank > 4 is currently not supported for the TOSA backend."
+            )
+
+        # Unsqueeze to 4D
+        if len(input_shape) < 4:
+            pad_n = 4 - len(input_shape)
+            new_shape = [1] * pad_n + list(input_shape)
+            dims_to_reduce = [dim + pad_n for dim in dims_to_reduce]
+
+            x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
+
+        # Reduce (h,w) by avg pool
+        dims_to_reduce_by_avgpool = [dim for dim in dims_to_reduce if dim >= 2]
+        x = self._reduce_by_average_pool(op, x, dims_to_reduce_by_avgpool, meta)
+
+        # Reduce (n, c) by reduce sum
+        dims_to_reduce_by_sum = [dim for dim in dims_to_reduce if dim < 2]
+        x = self._reduce_by_sum(op, x, dims_to_reduce_by_sum, meta, dtype)
+
+        # Reshape to correct output shape if necessary
+        if x.data.size() != output_shape:
+            x = super().call_operator(view_op, (x, output_shape), {}, meta, True)
+
+        return x
+
+    def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
+        if len(dims) == 0:
+            return input_node
+
+        input_shape = input_node.data.size()
+        output_shape = meta["val"].size()
+        N = prod((n for i, n in enumerate(input_shape) if i in dims))
         sum_op, full_op, mul_op = get_meandim_decomposition(op)

-        sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True)
+        sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True)
         full = super().call_operator(
-            full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True
+            full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
         )
         return super().call_operator(mul_op, (sum, full), {}, meta, True)
+
+    def _reduce_by_average_pool(self, op, input_node, dims, meta):
+        if len(dims) == 0:
+            return input_node
+
+        avgpool_op = get_avgpool(op)
+        input_shape = input_node.data.size()
+
+        stride = [1, 1]
+        if dims in ([2, 3], [3, 2]):
+            kernel_size = [input_shape[2], input_shape[3]]
+        elif dims == [3]:
+            kernel_size = [1, input_shape[3]]
+        elif dims == [2]:
+            kernel_size = [input_shape[2], 1]
+        else:
+            raise RuntimeError(f"Bad dims {dims} for {op} decomposition of mean_dim.")
+
+        return super().call_operator(
+            avgpool_op, (input_node, kernel_size, stride), {}, meta, True
+        )
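
To make the docstring's worked example above concrete, here is a minimal eager-mode PyTorch sketch (an illustrative check, not code from this commit; sizes are made up) reproducing the view -> avg_pool2d -> sum -> mul(1/N) -> view chain for a rank-3 input reduced over dims (0, 2):

import torch
import torch.nn.functional as F

c, h, w = 3, 4, 5  # example sizes, not from the commit
x = torch.randn(c, h, w)
expected = x.mean(dim=(0, 2))  # shape (h,)

y = x.view(1, c, h, w)                                  # pad rank to 4 from the left
y = F.avg_pool2d(y, kernel_size=(1, w), stride=(1, 1))  # reduce w with avg_pool
y = y.sum(dim=1, keepdim=True) * (1.0 / c)              # reduce c with sum + mul(1/c)
result = y.view(h)                                      # squeeze since keepdim=False

assert torch.allclose(expected, result, atol=1e-6)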

backends/arm/_passes/meandim_to_averagepool_pass.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 17 additions & 22 deletions
@@ -262,28 +262,23 @@ def is_node_supported(

         if node.op != "call_function":
             return True
-        if node.target == exir_ops.edge.aten.mean.dim:
-            dim = node.args[1]
-            needs_decomp = dim != [-1, -2]
-        else:
-            needs_decomp = node.target in [
-                exir_ops.edge.aten.div.Tensor,
-                exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
-                exir_ops.edge.aten.native_layer_norm.default,
-                exir_ops.edge.aten.mean.dim,
-                exir_ops.edge.aten._softmax.default,
-                exir_ops.edge.aten._log_softmax.default,
-                exir_ops.edge.aten.var.correction,
-                exir_ops.edge.aten.var.dim,
-                exir_ops.edge.aten.add.Scalar,
-                exir_ops.edge.aten.sqrt.default,
-                exir_ops.edge.aten.sub.Scalar,
-                exir_ops.edge.aten.mul.Scalar,
-                exir_ops.edge.aten.ne.Tensor,
-                exir_ops.edge.aten.ne.Scalar,
-                exir_ops.edge.aten.div.Scalar,
-                exir_ops.edge.aten.leaky_relu.default,
-            ]
+        needs_decomp = node.target in [
+            exir_ops.edge.aten.div.Tensor,
+            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+            exir_ops.edge.aten.native_layer_norm.default,
+            exir_ops.edge.aten._softmax.default,
+            exir_ops.edge.aten._log_softmax.default,
+            exir_ops.edge.aten.var.correction,
+            exir_ops.edge.aten.var.dim,
+            exir_ops.edge.aten.add.Scalar,
+            exir_ops.edge.aten.sqrt.default,
+            exir_ops.edge.aten.sub.Scalar,
+            exir_ops.edge.aten.mul.Scalar,
+            exir_ops.edge.aten.ne.Tensor,
+            exir_ops.edge.aten.ne.Scalar,
+            exir_ops.edge.aten.div.Scalar,
+            exir_ops.edge.aten.leaky_relu.default,
+        ]
         if needs_decomp:
             self.reporter.report_reject(node, "Needs to be decomposed.")
             return False

backends/arm/test/ops/test_layer_norm.py

Lines changed: 3 additions & 3 deletions
@@ -81,8 +81,8 @@ def test_native_layer_norm_tosa_BI(test_data):
         model,
         test_data,
         "torch.ops.aten.sub.Tensor",  # Just check for sub op included in the layernorm decomposition
+        symmetric_io_quantization=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()


@@ -95,8 +95,8 @@ def test_native_layer_norm_u55_BI(test_data):
         test_data,
         "torch.ops.aten.sub.Tensor",  # Just check for sub op included in the layernorm decomposition
         run_on_fvp=True,
+        symmetric_io_quantization=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()


@@ -109,6 +109,6 @@ def test_native_layer_norm_u85_BI(test_data):
         test_data,
         "torch.ops.aten.sub.Tensor",  # Just check for sub op included in the layernorm decomposition
         run_on_fvp=True,
+        symmetric_io_quantization=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
