-
Notifications
You must be signed in to change notification settings - Fork 619
NXP backend: Add support for 'aten::hardtanh' operator #12339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
robert-kalmar
merged 1 commit into
pytorch:main
from
nxp-upstream:upstream/main-nxp/EIEX-367-upstream-aten-hardtanh-operator
Jul 16, 2025
+209
−0
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
47 changes: 47 additions & 0 deletions
47
backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) 2025 NXP | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from executorch.backends.nxp.backend.ir.converter.node_converter import ( | ||
NodeConverter, | ||
Target, | ||
) | ||
from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( | ||
BuiltinOperator, | ||
) | ||
from torch.fx import Node | ||
from torch.nn import Parameter | ||
|
||
|
||
class HardTanhConverter(NodeConverter): | ||
supported_targets = [Target.RT700] | ||
|
||
# Maps possible input parameters of HardTanh to equivalent ReLU-based operators supported by TFLite. | ||
supported_modes_map = { | ||
(0.0, 6.0): BuiltinOperator.RELU6, | ||
(-1.0, 1.0): BuiltinOperator.RELU_N1_TO_1, | ||
(0.0, 1.0): BuiltinOperator.RELU_0_TO_1, | ||
(0.0, float("inf")): BuiltinOperator.RELU, | ||
} | ||
|
||
@staticmethod | ||
def _is_supported_in_IR( | ||
node: Node, parameters_mapping: dict[str, Parameter] | ||
) -> bool: | ||
_, min_value, max_value = node.args | ||
return (min_value, max_value) in HardTanhConverter.supported_modes_map.keys() | ||
|
||
def convert(self, node: Node): | ||
"""Convert 'aten::hardtanh' to it's supported ReLU equivalent.""" | ||
self.assert_convertible(node) | ||
|
||
t_op = self._create_tflite_op_with_io_tensors(node) | ||
|
||
_, min_value, max_value = node.args | ||
|
||
op = self.supported_modes_map[(min_value, max_value)] | ||
t_op.opcode_index = self.builder.op_code_index_for_op_type(op) | ||
|
||
self.builder.append_operators([t_op]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
126 changes: 126 additions & 0 deletions
126
backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import numpy as np | ||
import pytest | ||
import torch | ||
|
||
from executorch.backends.nxp.backend.edge_program_converter import ( | ||
EdgeProgramToIRConverter, | ||
) | ||
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import ( | ||
HardTanhConverter, | ||
) | ||
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program | ||
from executorch.backends.nxp.tests.executors import ( | ||
convert_run_compare, | ||
graph_contains_any_of_ops, | ||
ToNCHWPreprocess, | ||
ToNHWCPreprocess, | ||
) | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from torch.export import ExportedProgram | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def reseed_model_per_test_run(): | ||
torch.manual_seed(23) | ||
np.random.seed(23) | ||
|
||
|
||
class Relu6ConvBlock(torch.nn.Module): | ||
def __init__(self, conv_in_channels: int = 3, inplace: bool = False): | ||
super().__init__() | ||
self.block = torch.nn.Sequential( | ||
torch.nn.Conv2d( | ||
in_channels=conv_in_channels, out_channels=64, kernel_size=(4, 4) | ||
), | ||
torch.nn.ReLU6(inplace=inplace), | ||
) | ||
|
||
def forward(self, x): | ||
return self.block(x) | ||
|
||
|
||
class CustomHardTanhBlock(torch.nn.Module): | ||
def __init__( | ||
self, | ||
conv_in_channels: int = 3, | ||
min_act_val: float = -1.0, | ||
max_act_val: float = 1.0, | ||
inplace: bool = False, | ||
): | ||
super().__init__() | ||
self.block = torch.nn.Sequential( | ||
torch.nn.Conv2d( | ||
in_channels=conv_in_channels, out_channels=64, kernel_size=(4, 4) | ||
), | ||
torch.nn.Hardtanh( | ||
min_val=min_act_val, max_val=max_act_val, inplace=inplace | ||
), | ||
) | ||
|
||
def forward(self, x): | ||
return self.block(x) | ||
|
||
|
||
@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)]) | ||
@pytest.mark.parametrize("inplace", [True, False]) | ||
def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool): | ||
# The torch.nn.Relu6 inherits from torch.nn.Hardtanh, and hence represented as HardTanh in ATen. | ||
# Testing the hardtanh originated from torch.nn.Relu6 op. | ||
model = Relu6ConvBlock(conv_in_channels=input_shape[1], inplace=inplace) | ||
|
||
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") | ||
|
||
quantized_program = to_quantized_edge_program(model, input_shape).exported_program() | ||
|
||
tflite_flatbuffers_model, io_formats = converter_spy.spy_return | ||
exported_program: ExportedProgram = converter_spy.call_args.args[1] | ||
|
||
ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default] | ||
assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops) | ||
|
||
input_data = (np.random.random(input_shape) * 50).astype(np.int8) | ||
convert_run_compare( | ||
exported_program, | ||
tfl_model=tflite_flatbuffers_model, | ||
tflite_input_preprocess=ToNHWCPreprocess(), | ||
tflite_output_preprocess=ToNCHWPreprocess(), | ||
input_data=input_data, | ||
atol=1.0, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)]) | ||
@pytest.mark.parametrize( | ||
"activation_range", list(HardTanhConverter.supported_modes_map.keys()) | ||
) | ||
@pytest.mark.parametrize("inplace", [True, False]) | ||
def test_custom_hardtanh_quant( | ||
mocker, input_shape: tuple[int], activation_range: tuple[int, int], inplace: bool | ||
): | ||
min_val, max_val = activation_range | ||
model = CustomHardTanhBlock( | ||
conv_in_channels=input_shape[1], | ||
min_act_val=min_val, | ||
max_act_val=max_val, | ||
inplace=inplace, | ||
) | ||
|
||
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") | ||
|
||
quantized_program = to_quantized_edge_program(model, input_shape).exported_program() | ||
|
||
tflite_flatbuffers_model, io_formats = converter_spy.spy_return | ||
exported_program: ExportedProgram = converter_spy.call_args.args[1] | ||
|
||
ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default] | ||
assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops) | ||
|
||
input_data = (np.random.random(input_shape) * 50).astype(np.int8) | ||
convert_run_compare( | ||
exported_program, | ||
tfl_model=tflite_flatbuffers_model, | ||
tflite_input_preprocess=ToNHWCPreprocess(), | ||
tflite_output_preprocess=ToNCHWPreprocess(), | ||
input_data=input_data, | ||
atol=1.0, | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.