
Commit 0d201e0

StrycekSimon authored and robert-kalmar committed
NXP backend: Add support for 'aten::hardtanh' operator conversion
1 parent 206dcf4 commit 0d201e0

8 files changed: +208 -0 lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
     exir_ops.edge.aten.mm.default: MMConverter,  # noqa F405
     exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter,  # noqa F405
     exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
+    exir_ops.edge.aten.hardtanh.default: HardTanhConverter,  # noqa F405
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,  # noqa F405
 }

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,9 @@
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.convolution_converter import (
     ConvolutionConverter,
 )
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import (
+    HardTanhConverter,
+)
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.max_pool_2d_converter import (
     MaxPool2dConverter,
 )
@@ -48,4 +51,5 @@
     "ReLUConverter",
     "MaxPool2dConverter",
     "AvgPool2dConverter",
+    "HardTanhConverter",
 ]
backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+# Copyright (c) 2025 NXP
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    NodeConverter,
+    Target,
+)
+from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
+    BuiltinOperator,
+)
+from torch.fx import Node
+from torch.nn import Parameter
+
+
+class HardTanhConverter(NodeConverter):
+    supported_targets = [Target.RT700]
+
+    # Maps possible input parameters of HardTanh to equivalent ReLU-based operators supported by TFLite.
+    supported_modes_map = {
+        (0.0, 6.0): BuiltinOperator.RELU6,
+        (-1.0, 1.0): BuiltinOperator.RELU_N1_TO_1,
+        (0.0, 1.0): BuiltinOperator.RELU_0_TO_1,
+        (0.0, float("inf")): BuiltinOperator.RELU,
+    }
+
+    @staticmethod
+    def _is_supported_in_IR(
+        node: Node, parameters_mapping: dict[str, Parameter]
+    ) -> bool:
+        _, min_value, max_value = node.args
+        return (min_value, max_value) in HardTanhConverter.supported_modes_map.keys()
+
+    def convert(self, node: Node):
+        """Convert 'aten::hardtanh' to its supported ReLU equivalent."""
+        self.assert_convertible(node)
+
+        t_op = self._create_tflite_op_with_io_tensors(node)
+
+        _, min_value, max_value = node.args
+
+        op = self.supported_modes_map[(min_value, max_value)]
+        t_op.opcode_index = self.builder.op_code_index_for_op_type(op)
+
+        self.builder.append_operators([t_op])
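The converter works by matching the hardtanh clamp range against a fixed table, because for exactly these ranges hardtanh is numerically identical to one of TFLite's ReLU variants; any other (min, max) pair is rejected by _is_supported_in_IR and the node is left unconverted. A minimal sketch (plain PyTorch, not part of this commit) of the equivalences supported_modes_map relies on:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)

# hardtanh clamped to [0, 6] behaves exactly like ReLU6 ...
assert torch.allclose(F.hardtanh(x, min_val=0.0, max_val=6.0), F.relu6(x))
# ... and with an unbounded upper limit it degenerates to plain ReLU.
assert torch.allclose(F.hardtanh(x, min_val=0.0, max_val=float("inf")), F.relu(x))
# The (-1, 1) and (0, 1) ranges correspond to TFLite's RELU_N1_TO_1 and
# RELU_0_TO_1 operators; in torch terms they are simple clamps.
assert torch.allclose(F.hardtanh(x, min_val=-1.0, max_val=1.0), x.clamp(-1.0, 1.0))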

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -195,6 +195,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
     exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter,  # noqa F405
     exir_ops.edge.aten.mm.default: MMConverter,  # noqa F405
     exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
+    exir_ops.edge.aten.hardtanh.default: HardTanhConverter,  # noqa F405
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,  # noqa F405
 }

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,8 @@
     AvgPoolPattern,
     Conv1dPattern,
     Conv2dPattern,
+    HardTanhInPlacePattern,
+    HardTanhPattern,
     LinearPattern,
     MaxPoolPattern,
     PadPattern,
@@ -199,6 +201,8 @@ def __init__(self):
             NeutronAtenQuantizer(PermutePattern(), static_qconfig),
             NeutronAtenQuantizer(PadPattern(), static_qconfig),
             NeutronAtenQuantizer(ReluPattern(), static_qconfig),
+            NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
+            NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
             NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
             NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
             NeutronAtenQuantizer(ViewPattern(), static_qconfig),

backends/nxp/quantizer/patterns.py

Lines changed: 20 additions & 0 deletions
@@ -216,6 +216,26 @@ def get_anchors(
         )
 
 
+class HardTanhPattern(SharedSpecPattern):
+    """
+    Quantizer for the HardTanh operator. A shared quantization spec is selected, as activation functions
+    usually follow a computation layer.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.hardtanh.default]
+
+
+class HardTanhInPlacePattern(SharedSpecPattern):
+    """
+    Quantizer for the HardTanh operator with param inplace=True. A shared quantization spec is selected, as
+    activation functions usually follow a computation layer.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.hardtanh_.default]
+
+
 class LinearPattern(QuantizationPattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.linear.default]
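Two separate patterns are needed because the functional and in-place variants appear as different ATen targets (aten.hardtanh.default vs. aten.hardtanh_.default) in the graph the quantizer inspects. These patterns also cover torch.nn.ReLU6, which is implemented via hardtanh. A small sketch (independent of this commit; the module name is made up) showing that exporting a ReLU6 block is expected to yield a hardtanh node:

import torch
from torch.export import export


class ConvReLU6(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.act = torch.nn.ReLU6()

    def forward(self, x):
        return self.act(self.conv(x))


ep = export(ConvReLU6(), (torch.randn(1, 3, 32, 32),))
targets = {n.target for n in ep.graph.nodes if n.op == "call_function"}
# ReLU6 lowers to hardtanh(0, 6), so the exported graph should contain
# aten.hardtanh rather than a dedicated relu6 op.
print(torch.ops.aten.hardtanh.default in targets)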

backends/nxp/tests/executors.py

Lines changed: 6 additions & 0 deletions
@@ -15,6 +15,8 @@
 from executorch.backends.nxp.backend.ir import logger
 from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
 from torch.export import ExportedProgram
+from torch.fx.graph import Graph
+
 
 # If executed on i.MX platform, there is no tensorflow module. And typically the intention is to use the tflite python
 # interpreter available in tflite_runtime
@@ -278,6 +280,10 @@ def convert_run_compare(
     return tflite_executor, edge_program_executor
 
 
+def graph_contains_any_of_ops(graph: Graph, ops: list) -> bool:
+    return any(node.target in ops for node in graph.nodes)
+
+
 class OverrideSupportedTargets:
 
     def __init__(self, converter_class, *, new_targets):
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+import numpy as np
+import pytest
+import torch
+
+from executorch.backends.nxp.backend.edge_program_converter import (
+    EdgeProgramToIRConverter,
+)
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import (
+    HardTanhConverter,
+)
+from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executors import (
+    convert_run_compare,
+    graph_contains_any_of_ops,
+    ToNCHWPreprocess,
+    ToNHWCPreprocess,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import ExportedProgram
+
+
+@pytest.fixture(autouse=True)
+def reseed_model_per_test_run():
+    torch.manual_seed(23)
+    np.random.seed(23)
+
+
+class Relu6ConvBlock(torch.nn.Module):
+    def __init__(self, conv_in_channels: int = 3, inplace: bool = False):
+        super().__init__()
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv2d(
+                in_channels=conv_in_channels, out_channels=64, kernel_size=(4, 4)
+            ),
+            torch.nn.ReLU6(inplace=inplace),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class CustomHardTanhBlock(torch.nn.Module):
+    def __init__(
+        self,
+        conv_in_channels: int = 3,
+        min_act_val: float = -1.0,
+        max_act_val: float = 1.0,
+        inplace: bool = False,
+    ):
+        super().__init__()
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv2d(
+                in_channels=conv_in_channels, out_channels=64, kernel_size=(4, 4)
+            ),
+            torch.nn.Hardtanh(
+                min_val=min_act_val, max_val=max_act_val, inplace=inplace
+            ),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)])
+@pytest.mark.parametrize("inplace", [True, False])
+def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool):
+    # torch.nn.ReLU6 inherits from torch.nn.Hardtanh, and is hence represented as HardTanh in ATen.
+    # This tests a hardtanh originating from the torch.nn.ReLU6 op.
+    model = Relu6ConvBlock(conv_in_channels=input_shape[1], inplace=inplace)
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()
+
+    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+    exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
+    ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default]
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops)
+
+    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
+    convert_run_compare(
+        exported_program,
+        tfl_model=tflite_flatbuffers_model,
+        tflite_input_preprocess=ToNHWCPreprocess(),
+        tflite_output_preprocess=ToNCHWPreprocess(),
+        input_data=input_data,
+        atol=1.0,
+    )
+
+
+@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)])
+@pytest.mark.parametrize(
+    "activation_range", list(HardTanhConverter.supported_modes_map.keys())
+)
+@pytest.mark.parametrize("inplace", [True, False])
+def test_custom_hardtanh_quant(
+    mocker, input_shape: tuple[int], activation_range: tuple[int, int], inplace: bool
+):
+    min_val, max_val = activation_range
+    model = CustomHardTanhBlock(
+        conv_in_channels=input_shape[1],
+        min_act_val=min_val,
+        max_act_val=max_val,
+        inplace=inplace,
+    )
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()
+
+    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+    exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
+    ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default]
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops)
+
+    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
+    convert_run_compare(
+        exported_program,
+        tfl_model=tflite_flatbuffers_model,
+        tflite_input_preprocess=ToNHWCPreprocess(),
+        tflite_output_preprocess=ToNCHWPreprocess(),
+        input_data=input_data,
+        atol=1.0,
+    )
