1
1
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2
- # All rights reserved.
3
2
#
4
3
# This source code is licensed under the BSD-style license found in the
5
4
# LICENSE file in the root directory of this source tree.
6
5
7
- # pyre-unsafe
6
+ from math import prod
8
7
9
8
import torch
10
9
from executorch .backends .arm ._passes import ArmPass
@@ -28,42 +27,111 @@ def get_meandim_decomposition(op) -> tuple:
28
27
raise RuntimeError (f"Can't get meandim decomposition for op { op } " )
29
28
30
29
30
def get_avgpool(op):
    """Return the avg_pool2d overload living in the same dialect as `op`.

    `op` must be either the edge-dialect or the ATen `mean.dim` overload;
    any other operator raises a RuntimeError.
    """
    # Pick the op namespace first, then resolve the overload once.
    if op == exir_ops.edge.aten.mean.dim:
        aten_namespace = exir_ops.edge.aten
    elif op == torch.ops.aten.mean.dim:
        aten_namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get meandim decomposition for op {op}")
    return aten_namespace.avg_pool2d.default
36
+
37
+
38
def get_view(op):
    """Return the view_copy overload living in the same dialect as `op`.

    `op` must be either the edge-dialect or the ATen `mean.dim` overload;
    any other operator raises a RuntimeError.
    """
    # Pick the op namespace first, then resolve the overload once.
    if op == exir_ops.edge.aten.mean.dim:
        aten_namespace = exir_ops.edge.aten
    elif op == torch.ops.aten.mean.dim:
        aten_namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get meandim decomposition for op {op}")
    return aten_namespace.view_copy.default
44
+
45
+
31
46
class DecomposeMeanDimPass(ArmPass):
    """
    Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on
    which dims the mean is taken for:
        h,w -> avg_pool
        n,c -> sum + mul(1/N)
    For rank < 4, the input is first reshaped to 4D by padding with dim=1
    from the left. Inputs with rank > 4 are rejected with NotImplementedError.

    Example:
        x = mean_dim(x, (0,2), keepdim=False)  # x = (c,h,w)
    Becomes:
        x = view_copy.default(x, new_shape=(1,c,h,w))          # Reshape to work with avg_pool
        x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1))  # Reduce w with avg_pool
        x = sum.dim_IntList(x, dim=1, keepdims=True)           # Reduce c with sum
        x = mul.Tensor(x, 1/c)                                 # Divide by number of channels to get mean
        x = view_copy.default(x, new_shape=(h))                # Squeeze dims since keepdims = False
    """

    def call_operator(self, op, args, kwargs, meta):
        """Replace a `mean.dim` call with its avg_pool / sum+mul decomposition.

        Operators other than the edge or ATen `mean.dim` overloads are
        forwarded to the parent pass untouched.

        Raises:
            NotImplementedError: if the input has more than 4 dimensions.
        """
        if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
            return super().call_operator(op, args, kwargs, meta)

        x = get_node_arg(args, 0)
        input_shape = x.data.size()
        output_shape = meta["val"].size()
        dtype = meta["val"].dtype
        rank = len(input_shape)

        # Fail early, before doing any other work on an unsupported input.
        if rank > 4:
            raise NotImplementedError(
                f"{op} with rank > 4 is currently not supported for the TOSA backend."
            )

        dims_to_reduce = self._normalize_dims(get_node_arg(args, 1), rank)
        view_op = get_view(op)

        # Unsqueeze to 4D so that dims 2 and 3 can be fed to avg_pool2d.
        if rank < 4:
            pad_n = 4 - rank
            new_shape = [1] * pad_n + list(input_shape)
            # The reduced dims shift right by the number of prepended 1-dims.
            dims_to_reduce = [dim + pad_n for dim in dims_to_reduce]

            x = super().call_operator(view_op, (x, new_shape), {}, meta, True)

        # Reduce (h,w) by avg pool
        dims_to_reduce_by_avgpool = [dim for dim in dims_to_reduce if dim >= 2]
        x = self._reduce_by_average_pool(op, x, dims_to_reduce_by_avgpool, meta)

        # Reduce (n, c) by reduce sum
        dims_to_reduce_by_sum = [dim for dim in dims_to_reduce if dim < 2]
        x = self._reduce_by_sum(op, x, dims_to_reduce_by_sum, meta, dtype)

        # Reshape to correct output shape if necessary (undoes the 4D padding
        # and drops the reduced dims when keepdim is False).
        if x.data.size() != output_shape:
            x = super().call_operator(view_op, (x, output_shape), {}, meta, True)

        return x

    @staticmethod
    def _normalize_dims(dims, rank):
        """Return `dims` as a list of non-negative dim indices for `rank`.

        Accepts None, a single int, or an iterable of (possibly negative)
        ints. None and an empty iterable both select every dim, matching
        how aten.mean treats a missing/empty dim list.
        """
        if rank == 0:
            # A scalar has no dims to reduce; this also avoids `dim % 0`.
            return []
        if dims is None:
            return list(range(rank))
        if isinstance(dims, int):
            dims = [dims]
        normalized = [dim % rank for dim in dims]
        return normalized or list(range(rank))

    def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
        """Average `input_node` over `dims` as sum followed by mul with 1/N.

        N is the product of the sizes of the reduced dims. The sum keeps the
        reduced dims (keepdim=True); any squeezing is left to the caller.
        Returns `input_node` unchanged when `dims` is empty.
        """
        if not dims:
            return input_node

        input_shape = input_node.data.size()
        output_shape = meta["val"].size()
        N = prod(n for i, n in enumerate(input_shape) if i in dims)
        sum_op, full_op, mul_op = get_meandim_decomposition(op)

        summed = super().call_operator(
            sum_op, (input_node, dims, True), {}, meta, True
        )
        # The scale tensor holds a single element (all dims are 1) and is
        # combined with the sum through the mul op.
        scale = super().call_operator(
            full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
        )
        return super().call_operator(mul_op, (summed, scale), {}, meta, True)

    def _reduce_by_average_pool(self, op, input_node, dims, meta):
        """Average `input_node` over dims 2 and/or 3 with a stride-1 avg_pool2d.

        The kernel spans the full extent of each reduced dim and is 1 along
        a dim that is not reduced. Returns `input_node` unchanged when `dims`
        is empty.

        Raises:
            RuntimeError: if `dims` contains anything other than 2 or 3.
        """
        if not dims:
            return input_node

        if any(dim not in (2, 3) for dim in dims):
            raise RuntimeError(f"Bad dims {dims} for {op} decomposition of mean_dim.")

        avgpool_op = get_avgpool(op)
        input_shape = input_node.data.size()

        stride = [1, 1]
        kernel_size = [
            input_shape[2] if 2 in dims else 1,
            input_shape[3] if 3 in dims else 1,
        ]

        return super().call_operator(
            avgpool_op, (input_node, kernel_size, stride), {}, meta, True
        )
0 commit comments