Skip to content

Commit a8d7298

Browse files
NXP backend: Add support for depthwise and separable convolution. (#11215)
Co-authored-by: Martin Pavella <[email protected]>
1 parent bc33d47 commit a8d7298

File tree

8 files changed

+924
-58
lines changed

8 files changed

+924
-58
lines changed

backends/nxp/backend/edge_helper.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import torch
77
from torch.fx import Node
8+
from torch.nn import Parameter
89

910

1011
def input_tensor(node: Node, input_index: int) -> torch.Tensor:
@@ -38,3 +39,35 @@ def input_tensor_safe(node: Node, input_index: int) -> torch.Tensor | None:
3839
return None
3940

4041
return input_tensor(node, input_index)
42+
43+
44+
def node_is_static_tensor(node: Node, parameters_mapping: dict[str, Parameter]) -> bool:
    """Return `True` if the given `node` has static data in the `parameters_mapping` dict.

    :param node: Tensor node to check for data.
    :param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
                                `state_dict` attribute of an edge program.
    :return: `True` if `node.name` is a key of `parameters_mapping`, `False` otherwise.
    """
    # Test membership on the dict itself; going through `.keys()` is redundant.
    return node.name in parameters_mapping
51+
52+
53+
def node_is_effectively_static_tensor(
    node: Node, parameters_mapping: dict[str, Parameter]
) -> bool:
    """Return `True` if the given `node` has static data, or is a `Dequantize` node whose first input has static
     data. In the IR, the `node` will be turned into a static quantized tensor.

    :param node: Tensor node to check for data.
    :param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
                                `state_dict` attribute of an edge program.
    :return: `True` if `node` itself is static, or if it is a `dequantize_per_tensor` / `dequantize_per_channel`
              node whose first argument is static. `False` otherwise.
    """
    if node_is_static_tensor(node, parameters_mapping):
        return True

    def _is_dequantize(node_: Node) -> bool:
        # `placeholder`, `get_attr` and `output` nodes store a plain string in `target`, which has no `__name__`.
        #  Such nodes are never dequantize operators, so fall back to `None` instead of raising `AttributeError`.
        return getattr(node_.target, "__name__", None) in {
            "quantized_decomposed.dequantize_per_tensor.default",
            "quantized_decomposed.dequantize_per_channel.default",
        }

    # The first argument of a dequantize node is the (possibly static) quantized tensor.
    return _is_dequantize(node) and node_is_static_tensor(
        node.args[0], parameters_mapping
    )

backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py

Lines changed: 155 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,36 @@
66
import numpy as np
77
import torch
88

9-
from executorch.backends.nxp.backend.edge_helper import input_tensor, input_tensor_safe
9+
from executorch.backends.nxp.backend.edge_helper import (
10+
input_tensor,
11+
input_tensor_safe,
12+
node_is_effectively_static_tensor,
13+
)
1014
from executorch.backends.nxp.backend.ir.converter.conversion import (
1115
aten_translator,
1216
common,
1317
)
14-
from executorch.backends.nxp.backend.ir.converter.conversion.common import (
15-
OpsList,
16-
try_get_input,
17-
)
18+
from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input
1819
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1920
NodeConverter,
2021
Target,
2122
)
23+
from executorch.backends.nxp.backend.ir.converter.node_converters.shared import (
24+
conv_utils,
25+
)
26+
from executorch.backends.nxp.backend.ir.converter.node_converters.shared.conv_utils import (
27+
ConvConversionResult,
28+
ConvParameters,
29+
)
2230
from executorch.backends.nxp.backend.ir.converter.quantization_utils import (
2331
set_quantization_parameters_to_tensor,
2432
)
33+
from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data
2534
from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType
2635
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
2736
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
2837
conv_2d_options,
38+
depthwise_conv_2d_options,
2939
)
3040
from torch.fx import Node
3141
from torch.nn import Parameter
@@ -48,7 +58,29 @@ def _is_supported_in_IR(
4858
if output_padding != [0, 0]:
4959
return False
5060

51-
if groups != 1:
61+
if groups == 1:
62+
# Regular (pointwise) convolution.
63+
pass
64+
65+
elif conv_utils.group_conv_convertible_as_depthwise(
66+
node, groups
67+
) and node_is_effectively_static_tensor(node.args[1], parameters_mapping):
68+
# Depthwise convolution.
69+
# Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted weights. In case
70+
# the weights are dynamic, a Transpose operator would have to be added, which is not supported on Neutron.
71+
pass
72+
73+
elif conv_utils.group_conv_convertible_into_multiple_convolutions(node, groups):
74+
# Group Separable convolution.
75+
# Not supported natively by the eIQ Neutron.
76+
# In practice it can be computed by splitting the Group Separable Convolution into multiple Pointwise
77+
# Convolutions, which requires the Split and Concat operations. The Concat operation in Neutron Converter
78+
# SDK 25.03 requires the # of channels to be a multiple of the # of MAC units in the eIQ Neutron.
79+
# For this reason Group Separable Convolution is not delegated by default at this moment.
80+
return False
81+
82+
else:
83+
# All conversion options related to the `group` attribute have been checked and none of them can be used.
5284
return False
5385

5486
if input_tensor_safe(node, 2) is None:
@@ -57,71 +89,152 @@ def _is_supported_in_IR(
5789
if weight_tensor.dtype not in [torch.float32, torch.int8, torch.uint8]:
5890
return False
5991

60-
return True
61-
62-
def _convert_2d_conv(
63-
self, stride, padding, dilation, t_op: tflite_model.Operator
64-
) -> list[tflite_model.Operator]:
65-
ops = OpsList(middle_op=t_op)
66-
t_op.builtin_options = conv_2d_options.Conv2D()
67-
common.assign_2d_strides(t_op.builtin_options, stride)
68-
common.assign_2d_dilations(t_op.builtin_options, dilation)
69-
t_op.builtin_options.padding, explicit_padding = (
70-
aten_translator.convert_padding(padding)
71-
)
92+
if node.args[0].meta["val"].shape[0] != 1:
93+
# Only batch size 1 is supported on neutron.
94+
return False
7295

73-
if explicit_padding is not None:
74-
# Need to prepend a 'Pad' operator, which adds 0s. But these will be included in the computation!
75-
ops.add_pre(
76-
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
77-
)
96+
return True
7897

79-
input_tensor: tflite_model.Tensor = t_op.tmp_inputs[0]
80-
weight_tensor: tflite_model.Tensor = t_op.tmp_inputs[1]
81-
output_tensor: tflite_model.Tensor = t_op.tmp_outputs[0]
98+
# Type aliases naming the individual values returned by `_get_convolution_arguments()`.
# `Stride`, `Padding`, `Dilation` and `OutPadding` all alias the same `list[int]` type.
Stride = Padding = Dilation = OutPadding = list[int]
Transposed = bool
Groups = int
82101

83-
if (bias_tensor := try_get_input(t_op, 2)) is None:
102+
@staticmethod
def _get_convolution_arguments(
    conv_node: Node,
) -> tuple[Stride, Padding, Dilation, Transposed, OutPadding, Groups]:
    """Extract the non-tensor arguments of a convolution node.

    :param conv_node: Convolution node whose `args` hold exactly 9 values (see the comment below).
    :return: Tuple `(stride, padding, dilation, transposed, out_padding, groups)` taken from `conv_node.args`.
    :raises ValueError: If `conv_node.args` does not contain exactly 9 values (raised by the unpacking).
    """
    # The arguments of the conv are:
    #  [x, w, b, stride, padding, dilation, transposed, output padding, groups]
    #  https://github.com/pytorch/pytorch/blob/v2.6.0/aten/src/ATen/native/Convolution.cpp#L286-L291
    _, _, _, stride, padding, dilation, transposed, out_padding, groups = (
        conv_node.args
    )
    return stride, padding, dilation, transposed, out_padding, groups
113+
114+
# noinspection PyPep8Naming
def _convert_unpadded_2D(
    self, t_op: tflite_model.Operator, conv_params: ConvParameters
) -> conv_utils.ConvConversionResult:
    """Convert the `aten.convolution` into TFLite. The `padding` and `builtin_options` must be converted by the
     caller.

    :param t_op: TFLite operator with its `builtin_options` already assigned. Its `tmp_inputs` hold the input and
                  weight tensors (and optionally the bias), and `tmp_outputs` holds the output tensor.
    :param conv_params: Convolution parameters. Only `stride` and `dilation` are used here.
    :return: `ConvConversionResult` created from the input, weight, bias and output tensors, with `t_op` stored
              as the `middle_op` of its `ops_list`.
    :raises NotImplementedError: If the bias is missing and the weight type is not FLOAT32, INT8 or UINT8.
    """
    common.assign_2d_strides(t_op.builtin_options, conv_params.stride)
    common.assign_2d_dilations(t_op.builtin_options, conv_params.dilation)

    x: tflite_model.Tensor = t_op.tmp_inputs[0]
    w: tflite_model.Tensor = t_op.tmp_inputs[1]
    y: tflite_model.Tensor = t_op.tmp_outputs[0]

    if (b := try_get_input(t_op, 2)) is None:
        # Operator has no bias. Convolution aten op can omit it, TFLite can't.
        output_channels = w.shape.vector[0]

        # The zero bias is float32 for float weights and int32 for 8-bit quantized weights.
        if w.type == TensorType.FLOAT32:
            bias_type = np.dtype(np.float32)
        elif w.type in [TensorType.INT8, TensorType.UINT8]:
            bias_type = np.dtype(np.int32)
        else:
            # Should never happen.
            raise NotImplementedError(
                f"Convolution node with unsupported weight type: {w.type}"
            )

        b = self.builder.create_zeros_tensor(
            [output_channels], "zero_bias", bias_type, True
        )

        # Compute scale and zero point for bias tensor
        # NOTE(review): the diff view lost indentation; these statements are assumed to belong to the
        #  "no bias" branch - confirm against the repository.
        # NOTE(review): `x.quantization` and `w.quantization` are read unconditionally here - confirm that both
        #  tensors always carry quantization parameters, including the FLOAT32 case.
        input_scale = np.array(x.quantization.scale.vector)
        weight_scale = np.array(w.quantization.scale.vector)
        bias_scale = input_scale * weight_scale
        bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64)

        set_quantization_parameters_to_tensor(
            b, bias_scale, bias_zero_point, quantized_dimension=0
        )

    # Assign the operator its TFLite inputs and outputs
    t_op.tmp_inputs = [x, w, b]
    t_op.tmp_outputs = [y]

    conversion_result = ConvConversionResult(x, w, b, y)
    conversion_result.ops_list.middle_op = t_op

    return conversion_result
164+
165+
def _convert_2d_conv(
    self, t_op: tflite_model.Operator, conv_params: ConvParameters
) -> list[tflite_model.Operator]:
    """Convert a 2D `aten.convolution` into TFLite `DepthwiseConv2D`, `Conv2D`, or multiple separated convolutions.

    :param t_op: TFLite operator with its input and output tensors already assigned.
    :param conv_params: Convolution parameters (`stride`, `padding`, `dilation`, `groups`).
    :return: List of TFLite operators which together implement the convolution.
    :raises NotImplementedError: If the convolution is converted as depthwise and its weights are not static.
    """
    group_count = conv_params.groups
    as_depthwise = conv_utils.group_conv_convertible_as_depthwise(t_op, group_count)

    if not as_depthwise and conv_utils.group_conv_convertible_into_multiple_convolutions(
        t_op, group_count
    ):
        # Note: by default the Group Separable Convolution is rejected by the Neutron Partitioner, see the
        #  ConvolutionConverter._is_supported_in_IR()
        t_op.builtin_options = conv_2d_options.Conv2D()

        return conv_utils.create_separated_convolutions_based_on_group(
            t_op,
            conv_params,
            self.builder,
            self._convert_unpadded_2D,
            conv_utils.conv_op_factory,
        )

    # Both remaining variants share the same conversion; only the builtin options (and the weights) differ.
    if as_depthwise:
        t_op.builtin_options = depthwise_conv_2d_options.DepthwiseConv2D()
    else:
        # Convert to regular `Conv2D`.
        t_op.builtin_options = conv_2d_options.Conv2D()

    result = self._convert_unpadded_2D(t_op, conv_params)

    padding_mode, explicit_padding = aten_translator.convert_padding(
        conv_params.padding
    )
    t_op.builtin_options.padding = padding_mode
    if explicit_padding is not None:
        # Need to prepend a 'Pad' operator, which adds 0s.
        result.ops_list.add_pre(
            self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
        )

    if as_depthwise:
        # DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels]
        depthwise_weights = result.conv_weight_tensor
        if not tensor_has_data(depthwise_weights):
            raise NotImplementedError("Dynamic Depthwise Conv weights.")

        # Transpose cloned tensor statically
        t_op.tmp_inputs[1] = self.builder.create_transposed_tensor(
            depthwise_weights, [3, 1, 2, 0]
        )

    return result.ops_list.flatten()
116223

117224
def convert(self, node: Node):
    """Convert the convolution `node` into TFLite operators and append them to the builder.

    :param node: Convolution node to convert. It must pass `assert_convertible()`.
    :raises NotImplementedError: If the weight tensor is not 4D, i.e. the convolution is not 2D.
    """
    self.assert_convertible(node)

    stride, padding, dilation, _, _, groups = self._get_convolution_arguments(node)

    t_op = self._create_tflite_op_with_io_tensors(node)
    conv_params = ConvParameters(stride, padding, dilation, groups)

    # The rank of the weight tensor determines the dimensionality of the convolution.
    rank = t_op.tmp_inputs[1].shape.len()
    if rank != 4:
        # Should never get here.
        raise NotImplementedError(f"{rank - 2}D convolution is not supported.")

    # Conv2D
    conv_ops = self._convert_2d_conv(t_op, conv_params)
    self.builder.append_operators(conv_ops)

0 commit comments

Comments
 (0)