import numpy as np
import torch

- from executorch.backends.nxp.backend.edge_helper import input_tensor, input_tensor_safe
+ from executorch.backends.nxp.backend.edge_helper import (
+     input_tensor,
+     input_tensor_safe,
+     node_is_effectively_static_tensor,
+ )
from executorch.backends.nxp.backend.ir.converter.conversion import (
    aten_translator,
    common,
)
- from executorch.backends.nxp.backend.ir.converter.conversion.common import (
-     OpsList,
-     try_get_input,
- )
+ from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input
from executorch.backends.nxp.backend.ir.converter.node_converter import (
    NodeConverter,
    Target,
)
+ from executorch.backends.nxp.backend.ir.converter.node_converters.shared import (
+     conv_utils,
+ )
+ from executorch.backends.nxp.backend.ir.converter.node_converters.shared.conv_utils import (
+     ConvConversionResult,
+     ConvParameters,
+ )
from executorch.backends.nxp.backend.ir.converter.quantization_utils import (
    set_quantization_parameters_to_tensor,
)
+ from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data
from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
    conv_2d_options,
+     depthwise_conv_2d_options,
)
from torch.fx import Node
from torch.nn import Parameter
@@ -48,7 +58,29 @@ def _is_supported_in_IR(
        if output_padding != [0, 0]:
            return False

-         if groups != 1:
+         if groups == 1:
+             # Regular (pointwise) convolution.
+             pass
+
+         elif conv_utils.group_conv_convertible_as_depthwise(
+             node, groups
+         ) and node_is_effectively_static_tensor(node.args[1], parameters_mapping):
+             # Depthwise convolution.
+             # Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted weights.
+             # If the weights were dynamic, a Transpose operator would have to be added, which is not supported
+             # on Neutron.
+             pass
+
+         elif conv_utils.group_conv_convertible_into_multiple_convolutions(node, groups):
+             # Group Separable convolution.
+             # Not supported natively by the eIQ Neutron. In practice it can be computed by splitting the Group
+             # Separable Convolution into multiple Pointwise Convolutions, using the Split and Concat operations.
+             # The Concat operation in Neutron Converter SDK 25.03 requires the number of channels to be a
+             # multiple of the number of MAC units in the eIQ Neutron. For this reason, the Group Separable
+             # Convolution is not delegated by default at this moment.
+             return False
+
+         else:
+             # All conversion options related to the `group` attribute have been checked and none of them can be used.
            return False

        if input_tensor_safe(node, 2) is None:
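A minimal sketch of the group-handling ladder above, assuming the usual PyTorch grouped-convolution constraints (`in_channels` and `out_channels` both divisible by `groups`); the helper below is illustrative only and is not the actual `conv_utils` API:

def classify_grouped_conv(in_channels: int, out_channels: int, groups: int) -> str:
    """Illustrative stand-in for the three branches checked in _is_supported_in_IR."""
    if groups == 1:
        return "regular"  # converted to a plain Conv2D
    if groups == in_channels and out_channels % groups == 0:
        return "depthwise"  # converted to DepthwiseConv2D (static weights only)
    if in_channels % groups == 0 and out_channels % groups == 0:
        return "group-separable"  # split into per-group convolutions; not delegated by default
    return "unsupported"

# For example, a conv with 32 input channels, 32 output channels and groups=32 is depthwise-like:
assert classify_grouped_conv(32, 32, 32) == "depthwise"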
@@ -57,71 +89,152 @@ def _is_supported_in_IR(
        if weight_tensor.dtype not in [torch.float32, torch.int8, torch.uint8]:
            return False

-         return True
-
-     def _convert_2d_conv(
-         self, stride, padding, dilation, t_op: tflite_model.Operator
-     ) -> list[tflite_model.Operator]:
-         ops = OpsList(middle_op=t_op)
-         t_op.builtin_options = conv_2d_options.Conv2D()
-         common.assign_2d_strides(t_op.builtin_options, stride)
-         common.assign_2d_dilations(t_op.builtin_options, dilation)
-         t_op.builtin_options.padding, explicit_padding = (
-             aten_translator.convert_padding(padding)
-         )
+         if node.args[0].meta["val"].shape[0] != 1:
+             # Only batch size 1 is supported on Neutron.
+             return False

-         if explicit_padding is not None:
-             # Need to prepend a 'Pad' operator, which adds 0s. But these will be included in the computation!
-             ops.add_pre(
-                 self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
-             )
+         return True

-         input_tensor: tflite_model.Tensor = t_op.tmp_inputs[0]
-         weight_tensor: tflite_model.Tensor = t_op.tmp_inputs[1]
-         output_tensor: tflite_model.Tensor = t_op.tmp_outputs[0]
+     Stride = Padding = Dilation = OutPadding = list[int]
+     Transposed = bool
+     Groups = int

-         if (bias_tensor := try_get_input(t_op, 2)) is None:
+     @staticmethod
+     def _get_convolution_arguments(
+         conv_node: Node,
+     ) -> (Stride, Padding, Dilation, Transposed, OutPadding, Groups):
+         # The arguments of the conv are:
+         # [x, w, b, stride, padding, dilation, transposed, output padding, groups]
+         # https://github.com/pytorch/pytorch/blob/v2.6.0/aten/src/ATen/native/Convolution.cpp#L286-L291
+         _, _, _, stride, padding, dilation, transposed, out_padding, groups = (
+             conv_node.args
+         )
+         return stride, padding, dilation, transposed, out_padding, groups
+
+     # noinspection PyPep8Naming
+     def _convert_unpadded_2D(
+         self, t_op: tflite_model.Operator, conv_params: ConvParameters
+     ) -> conv_utils.ConvConversionResult:
+         """Convert the `aten.convolution` into TFLite. The `padding` and `builtin_options` must be converted by the
+         caller.
+         """
+         common.assign_2d_strides(t_op.builtin_options, conv_params.stride)
+         common.assign_2d_dilations(t_op.builtin_options, conv_params.dilation)
+
+         x: tflite_model.Tensor = t_op.tmp_inputs[0]
+         w: tflite_model.Tensor = t_op.tmp_inputs[1]
+         y: tflite_model.Tensor = t_op.tmp_outputs[0]
+
+         if (b := try_get_input(t_op, 2)) is None:
            # Operator has no bias. Convolution aten op can omit it, TFLite can't.
-             output_channels = weight_tensor.shape.vector[0]
+             output_channels = w.shape.vector[0]

-             if weight_tensor.type == TensorType.FLOAT32:
+             if w.type == TensorType.FLOAT32:
                bias_type = np.dtype(np.float32)
-             elif weight_tensor.type in [TensorType.INT8, TensorType.UINT8]:
+             elif w.type in [TensorType.INT8, TensorType.UINT8]:
                bias_type = np.dtype(np.int32)
            else:
                # Should never happen.
                raise NotImplementedError(
-                     f"Convolution node with unsupported weight type: {weight_tensor.type}"
+                     f"Convolution node with unsupported weight type: {w.type}"
                )

-             bias_tensor = self.builder.create_zeros_tensor(
+             b = self.builder.create_zeros_tensor(
                [output_channels], "zero_bias", bias_type, True
            )

            # Compute scale and zero point for bias tensor
-             input_scale = np.array(input_tensor.quantization.scale.vector)
-             weight_scale = np.array(weight_tensor.quantization.scale.vector)
+             input_scale = np.array(x.quantization.scale.vector)
+             weight_scale = np.array(w.quantization.scale.vector)
            bias_scale = input_scale * weight_scale
            bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64)

            set_quantization_parameters_to_tensor(
-                 bias_tensor, bias_scale, bias_zero_point, quantized_dimension=0
+                 b, bias_scale, bias_zero_point, quantized_dimension=0
            )

        # Assign the operator its TFLite inputs and outputs
-         t_op.tmp_inputs = [input_tensor, weight_tensor, bias_tensor]
-         t_op.tmp_outputs = [output_tensor]
+         t_op.tmp_inputs = [x, w, b]
+         t_op.tmp_outputs = [y]
+
+         conversion_result = ConvConversionResult(x, w, b, y)
+         conversion_result.ops_list.middle_op = t_op
+
+         return conversion_result
+
+     def _convert_2d_conv(
+         self, t_op: tflite_model.Operator, conv_params: ConvParameters
+     ) -> list[tflite_model.Operator]:
+         if conv_utils.group_conv_convertible_as_depthwise(
+             t_op, conv_params.groups
+         ):  # Convert to `DepthwiseConv2D`.
+             t_op.builtin_options = depthwise_conv_2d_options.DepthwiseConv2D()
+
+             conversion_result = self._convert_unpadded_2D(t_op, conv_params)
+             t_op.builtin_options.padding, explicit_padding = (
+                 aten_translator.convert_padding(conv_params.padding)
+             )
+             if explicit_padding is not None:
+                 # Need to prepend a 'Pad' operator, which adds 0s.
+                 conversion_result.ops_list.add_pre(
+                     self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
+                 )
+
+             # DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels].
+             perm = [3, 1, 2, 0]
+             weight_tensor = conversion_result.conv_weight_tensor
+             if tensor_has_data(weight_tensor):
+                 # Transpose the cloned tensor statically.
+                 t_op.tmp_inputs[1] = self.builder.create_transposed_tensor(
+                     weight_tensor, perm
+                 )
+             else:
+                 raise NotImplementedError("Dynamic Depthwise Conv weights.")
+
+         elif conv_utils.group_conv_convertible_into_multiple_convolutions(
+             t_op, conv_params.groups
+         ):
+             # Note: by default the Group Separable Convolution is rejected by the Neutron Partitioner, see
+             # ConvolutionConverter._is_supported_in_IR().
+             t_op.builtin_options = conv_2d_options.Conv2D()
+
+             return conv_utils.create_separated_convolutions_based_on_group(
+                 t_op,
+                 conv_params,
+                 self.builder,
+                 self._convert_unpadded_2D,
+                 conv_utils.conv_op_factory,
+             )
+
+         else:
+             # Convert to regular `Conv2D`.
+             t_op.builtin_options = conv_2d_options.Conv2D()
+             conversion_result = self._convert_unpadded_2D(t_op, conv_params)
+             t_op.builtin_options.padding, explicit_padding = (
+                 aten_translator.convert_padding(conv_params.padding)
+             )
+             if explicit_padding is not None:
+                 # Need to prepend a 'Pad' operator, which adds 0s.
+                 conversion_result.ops_list.add_pre(
+                     self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
+                 )

-         return ops.flatten()
+         return conversion_result.ops_list.flatten()

    def convert(self, node: Node):
        self.assert_convertible(node)

-         stride = node.args[3]
-         padding = node.args[4]
-         dilation = node.args[5]
+         stride, padding, dilation, _, _, groups = self._get_convolution_arguments(node)

        t_op = self._create_tflite_op_with_io_tensors(node)
-         ops_to_add = self._convert_2d_conv(stride, padding, dilation, t_op)
+         conv_params = ConvParameters(stride, padding, dilation, groups)
+
+         rank = t_op.tmp_inputs[1].shape.len()
+         if rank == 4:  # Conv2D
+             ops_to_add = self._convert_2d_conv(t_op, conv_params)
+         else:
+             raise NotImplementedError(
+                 f"{rank - 2}D convolution is not supported."
+             )  # Should never get here.

        self.builder.append_operators(ops_to_add)
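For the zero-bias path added in `_convert_unpadded_2D`: when the weights are quantized, the synthetic int32 bias gets a per-output-channel scale of `input_scale * weight_scale` and a zero point of 0. A small NumPy sketch of that relation (the scale and bias values below are made up for illustration):

import numpy as np

input_scale = np.array([0.02])  # per-tensor input scale
weight_scale = np.array([0.1, 0.05, 0.2])  # per-channel weight scales (quantized_dimension=0)
bias_scale = input_scale * weight_scale  # one bias scale per output channel
bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64)

# Quantizing a real-valued bias with these parameters gives one int32 value per output channel:
real_bias = np.array([0.5, -0.25, 1.0])
quantized_bias = np.round(real_bias / bias_scale).astype(np.int32) + bias_zero_point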
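The `perm = [3, 1, 2, 0]` used in the depthwise branch of `_convert_2d_conv` reorders a channels-last Conv2D-style kernel into the DepthwiseConv2D layout. A sketch with NumPy, assuming the weight tensor arrives as `[output_channels, kernel_height, kernel_width, kernel_channels]` with `kernel_channels == 1` in the depthwise case (the shapes are illustrative):

import numpy as np

# 8 depthwise filters with a 3x3 kernel and one input channel each.
w_conv2d_layout = np.zeros((8, 3, 3, 1), dtype=np.float32)  # [O, kH, kW, C]

perm = [3, 1, 2, 0]
w_depthwise_layout = np.transpose(w_conv2d_layout, perm)  # [C, kH, kW, O]
assert w_depthwise_layout.shape == (1, 3, 3, 8)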