Skip to content

Commit d147a2c

Browse files
Arm backend: Fix decompose_meandim_pass bug (#11141)
Previously this pass could insert avg_pool ops with configurations that are not supported on Ethos-U55. Check each candidate op with the AvgPool2dSupported checker and, for unsupported cases, decompose fully using sum instead. Also modify the test cases to cover this and use more varied shapes. Signed-off-by: Adrian Lundell <[email protected]>
1 parent 087a27c commit d147a2c

File tree

5 files changed

+153
-77
lines changed

5 files changed

+153
-77
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
9090
self.add_pass(ConvertMmToBmmPass())
9191
self.add_pass(DecomposeLinearPass())
9292
self.add_pass(DecomposeLinearVectorNormPass())
93-
self.add_pass(DecomposeMeanDimPass())
93+
self.add_pass(
94+
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
95+
)
9496
self.add_pass(ConvertFullLikeToFullPass())
9597
self.add_pass(ConvertToClampPass())
9698
self.add_pass(ConvertMinMaxPass())
@@ -144,7 +146,9 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
144146
self.add_pass(DecomposeBatchNormPass())
145147
self.add_pass(DecomposeLayerNormPass())
146148
self.add_pass(DecomposeVarPass())
147-
self.add_pass(DecomposeMeanDimPass())
149+
self.add_pass(
150+
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
151+
)
148152
self.add_pass(DecomposeNotEqualPass())
149153
self.add_pass(DecomposeDivPass())
150154
self.add_pass(DecomposeSoftmaxPass())
@@ -209,7 +213,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
209213
self.add_pass(ScalarsToAttributePass())
210214
self.add_pass(DecomposeLayerNormPass())
211215
self.add_pass(DecomposeVarPass())
212-
self.add_pass(DecomposeMeanDimPass())
216+
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
213217
self.add_pass(DecomposeNotEqualPass())
214218
self.add_pass(DecomposeCosineSimilarityPass())
215219
self.add_pass(DecomposeDivPass())

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import torch
99
from executorch.backends.arm._passes import ArmPass
1010
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
11+
from executorch.backends.arm.operator_support.pool_2d_support import AvgPool2dSupported
12+
from executorch.exir.backend.utils import WhyNoPartitionReporter
1113
from executorch.exir.dialects._ops import ops as exir_ops
1214

1315

@@ -60,6 +62,14 @@ class DecomposeMeanDimPass(ArmPass):
6062
x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False
6163
"""
6264

65+
def __init__(self, graph_module, tosa_spec):
66+
super().__init__()
67+
self._graph_module = graph_module
68+
self._tosa_spec = tosa_spec
69+
self._avg_pool_checker = AvgPool2dSupported(
70+
self._tosa_spec, WhyNoPartitionReporter()
71+
)
72+
6373
def call_operator(self, op, args, kwargs, meta):
6474
if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
6575
return super().call_operator(op, args, kwargs, meta)
@@ -86,13 +96,11 @@ def call_operator(self, op, args, kwargs, meta):
8696

8797
x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
8898

89-
# Reduce (h,w) by avg pool
90-
dims_to_reduce_by_avgpool = [dim for dim in dims_to_reduce if dim >= 2]
91-
x = self._reduce_by_average_pool(op, x, dims_to_reduce_by_avgpool, meta)
99+
# Reduce (h,w) dims by avg pool if possible
100+
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
92101

93-
# Reduce (n, c) by reduce sum
94-
dims_to_reduce_by_sum = [dim for dim in dims_to_reduce if dim < 2]
95-
x = self._reduce_by_sum(op, x, dims_to_reduce_by_sum, meta, dtype)
102+
# Reduce remaining dims by sum
103+
x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)
96104

97105
# Reshape to correct output shape if necessary
98106
if x.data.size() != output_shape:
@@ -116,22 +124,41 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
116124
return super().call_operator(mul_op, (sum, full), {}, meta, True)
117125

118126
def _reduce_by_average_pool(self, op, input_node, dims, meta):
119-
if len(dims) == 0:
120-
return input_node
127+
dims_to_reduce_by_avgpool = [dim for dim in dims if dim >= 2]
128+
if len(dims_to_reduce_by_avgpool) == 0:
129+
return input_node, dims
130+
131+
dims_to_reduce_by_sum = [dim for dim in dims if dim < 2]
121132

122133
avgpool_op = get_avgpool(op)
123134
input_shape = input_node.data.size()
124135

125136
stride = [1, 1]
126-
if dims in ([2, 3], [3, 2]):
137+
if dims_to_reduce_by_avgpool in ([2, 3], [3, 2]):
127138
kernel_size = [input_shape[2], input_shape[3]]
128-
elif dims == [3]:
139+
elif dims_to_reduce_by_avgpool == [3]:
129140
kernel_size = [1, input_shape[3]]
130-
elif dims == [2]:
141+
elif dims_to_reduce_by_avgpool == [2]:
131142
kernel_size = [input_shape[2], 1]
132143
else:
133-
raise RuntimeError(f"Bad dims {dims} for {op} decomposition of mean_dim.")
144+
raise RuntimeError(
145+
f"Bad dims {dims_to_reduce_by_avgpool} for {op} decomposition of mean_dim."
146+
)
134147

135-
return super().call_operator(
136-
avgpool_op, (input_node, kernel_size, stride), {}, meta, True
148+
args = (input_node, kernel_size, stride)
149+
150+
avg_pool_node = self._graph_module.graph.create_node(
151+
"call_function", avgpool_op, args
152+
)
153+
is_supported = self._avg_pool_checker.is_node_tosa_supported(
154+
avg_pool_node, self._tosa_spec
137155
)
156+
157+
if is_supported:
158+
return (
159+
super().call_operator(avgpool_op, args, {}, meta, True),
160+
dims_to_reduce_by_sum,
161+
)
162+
163+
else:
164+
return input_node, dims

backends/arm/operator_support/pool_2d_support.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99
import torch.fx as fx
10+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1011
from executorch.backends.arm.operator_support.tosa_supported_operators import (
1112
register_tosa_support_check,
1213
SupportedTOSAOperatorCheck,
@@ -50,7 +51,11 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
5051
return True
5152

5253
# U55 case, Vela 4.2.0 (25.02 release)
53-
shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
54+
input_arg = node.args[0]
55+
if isinstance(input_arg, torch.fx.Node):
56+
input_arg = get_first_fake_tensor(input_arg)
57+
shape = input_arg.data.shape # type: ignore[union-attr]
58+
5459
kernel = cast(tuple[int, int], node.args[1])
5560
stride = cast(tuple[int, int], node.args[2])
5661
if len(node.args) > 3:

backends/arm/test/ops/test_mean_dim.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -91,52 +91,52 @@ class MeanDim(torch.nn.Module):
9191
True,
9292
),
9393
"rank_2_keepdim": lambda: (
94-
torch.rand(7, 7),
94+
torch.rand(7, 3),
9595
(0, 1),
9696
True,
9797
),
9898
"rank_3_keepdim": lambda: (
99-
torch.rand(7, 7, 7),
99+
torch.rand(5, 7, 3),
100100
(0, 1, 2),
101101
True,
102102
),
103103
"rand_1_keepdim": lambda: (
104-
torch.rand(1, 7, 7, 7),
104+
torch.rand(1, 5, 7, 3),
105105
(1),
106106
True,
107107
),
108108
"rand_2_keepdim": lambda: (
109-
torch.rand(1, 7, 7, 7),
109+
torch.rand(1, 5, 7, 3),
110110
(2),
111111
True,
112112
),
113113
"rand_3_keepdim": lambda: (
114-
torch.rand(1, 7, 7, 7),
114+
torch.rand(1, 5, 7, 3),
115115
(3),
116116
True,
117117
),
118118
"rand_12_keepdim": lambda: (
119-
torch.rand(1, 7, 7, 7),
119+
torch.rand(1, 5, 7, 3),
120120
(1, 2),
121121
True,
122122
),
123123
"rand_13_keepdim": lambda: (
124-
torch.rand(1, 7, 7, 7),
124+
torch.rand(1, 5, 7, 3),
125125
(1, 3),
126126
True,
127127
),
128128
"rand_23_keepdim": lambda: (
129-
torch.rand(1, 7, 7, 7),
129+
torch.rand(1, 5, 7, 3),
130130
(2, 3),
131131
True,
132132
),
133133
"rand_123_keepdim": lambda: (
134-
torch.rand(1, 7, 7, 7),
134+
torch.rand(1, 5, 7, 3),
135135
(1, 2, 3),
136136
True,
137137
),
138138
"rand_0123_keepdim": lambda: (
139-
torch.rand(1, 7, 7, 7),
139+
torch.rand(1, 5, 7, 3),
140140
(0, 1, 2, 3),
141141
True,
142142
),
@@ -146,55 +146,60 @@ class MeanDim(torch.nn.Module):
146146
False,
147147
),
148148
"rank_2": lambda: (
149-
torch.rand(7, 7),
149+
torch.rand(5, 7),
150150
(-2, -1),
151151
False,
152152
),
153153
"rank_3": lambda: (
154-
torch.rand(7, 7, 7),
154+
torch.rand(5, 7, 3),
155155
(-3, -2, -1),
156156
False,
157157
),
158158
"rand_1": lambda: (
159-
torch.rand(1, 7, 7, 7),
159+
torch.rand(1, 5, 7, 3),
160160
(-3),
161161
False,
162162
),
163163
"rand_2": lambda: (
164-
torch.rand(1, 7, 7, 7),
164+
torch.rand(1, 5, 7, 3),
165165
(-2),
166166
False,
167167
),
168168
"rand_3": lambda: (
169-
torch.rand(1, 7, 7, 7),
169+
torch.rand(1, 5, 7, 3),
170170
(-1),
171171
False,
172172
),
173173
"rand_12": lambda: (
174-
torch.rand(1, 7, 7, 7),
174+
torch.rand(1, 5, 7, 3),
175175
(-3, -2),
176176
False,
177177
),
178178
"rand_13": lambda: (
179-
torch.rand(1, 7, 7, 7),
179+
torch.rand(1, 5, 7, 3),
180180
(-3, -1),
181181
False,
182182
),
183183
"rand_23": lambda: (
184-
torch.rand(1, 7, 7, 7),
184+
torch.rand(1, 5, 7, 3),
185185
(-2, -1),
186186
False,
187187
),
188188
"rand_123": lambda: (
189-
torch.rand(1, 7, 7, 7),
189+
torch.rand(1, 5, 7, 3),
190190
(-3, -2, -1),
191191
False,
192192
),
193193
"rand_0123": lambda: (
194-
torch.rand(1, 7, 7, 7),
194+
torch.rand(1, 5, 7, 3),
195195
(-4, -3, -2, -1),
196196
False,
197197
),
198+
"u55_avg_pool_not_supported": lambda: (
199+
torch.rand(1, 1, 1, 257),
200+
(0, 1, 2, 3),
201+
True,
202+
),
198203
}
199204
torch_op = "torch.ops.aten.mean.dim"
200205
exir_op = "executorch_exir_dialects_edge__ops_aten_mean_dim"
@@ -241,7 +246,13 @@ def test_mean_dim_u55_BI(test_data):
241246
[], # Might be sum, avgpool, or both
242247
run_on_fvp=True,
243248
symmetric_io_quantization=True,
244-
).dump_artifact("export")
249+
)
250+
pipeline.add_stage_after(
251+
"export",
252+
pipeline.tester.check_not,
253+
["torch.ops.aten.adaptive_avg_pool2d.default"],
254+
suffix="avg_pool",
255+
)
245256
pipeline.run()
246257

247258

0 commit comments

Comments
 (0)