Skip to content

NXP backend: Add support for 'aten::hardtanh' operator #12339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/nxp/backend/edge_program_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten.hardtanh.default: HardTanhConverter, # noqa F405
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.convolution_converter import (
ConvolutionConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import (
HardTanhConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.max_pool_2d_converter import (
MaxPool2dConverter,
)
Expand Down Expand Up @@ -48,4 +51,5 @@
"ReLUConverter",
"MaxPool2dConverter",
"AvgPool2dConverter",
"HardTanhConverter",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.nxp.backend.ir.converter.node_converter import (
NodeConverter,
Target,
)
from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
BuiltinOperator,
)
from torch.fx import Node
from torch.nn import Parameter


class HardTanhConverter(NodeConverter):
    """Convert 'aten::hardtanh' nodes to an equivalent ReLU-based TFLite operator.

    Only the (min, max) ranges listed in `supported_modes_map` can be converted; any other range is reported as
    unsupported by `_is_supported_in_IR`.
    """

    supported_targets = [Target.RT700]

    # Maps possible input parameters of HardTanh to equivalent ReLU-based operators supported by TFLite.
    supported_modes_map = {
        (0.0, 6.0): BuiltinOperator.RELU6,
        (-1.0, 1.0): BuiltinOperator.RELU_N1_TO_1,
        (0.0, 1.0): BuiltinOperator.RELU_0_TO_1,
        (0.0, float("inf")): BuiltinOperator.RELU,
    }

    # Default bounds of 'aten::hardtanh' (`min_val=-1`, `max_val=1`). They are used when the node does not carry
    # the bounds explicitly.
    _DEFAULT_MIN_VALUE = -1.0
    _DEFAULT_MAX_VALUE = 1.0

    @staticmethod
    def _get_activation_range(node: Node) -> tuple:
        """Return the `(min_value, max_value)` clamp range of a hardtanh `node`.

        The bounds are read from the positional arguments when present, then from the `min_val` / `max_val`
        keyword arguments, and finally fall back to the operator defaults. This avoids a `ValueError` from tuple
        unpacking when the node has fewer than 3 positional arguments.
        """
        args = node.args
        kwargs = node.kwargs

        if len(args) > 1:
            min_value = args[1]
        else:
            min_value = kwargs.get("min_val", HardTanhConverter._DEFAULT_MIN_VALUE)

        if len(args) > 2:
            max_value = args[2]
        else:
            max_value = kwargs.get("max_val", HardTanhConverter._DEFAULT_MAX_VALUE)

        return min_value, max_value

    @staticmethod
    def _is_supported_in_IR(
        node: Node, parameters_mapping: dict[str, Parameter]
    ) -> bool:
        """Return True if the clamp range of `node` maps to one of the supported ReLU-based operators.

        :param node: The hardtanh node to check.
        :param parameters_mapping: Not used by this converter.
        """
        activation_range = HardTanhConverter._get_activation_range(node)
        return activation_range in HardTanhConverter.supported_modes_map

    def convert(self, node: Node):
        """Convert 'aten::hardtanh' to its supported ReLU equivalent."""
        self.assert_convertible(node)

        t_op = self._create_tflite_op_with_io_tensors(node)

        # The range is guaranteed to be a key of the map here, provided `assert_convertible` rejects nodes for
        # which `_is_supported_in_IR` returns False -- NOTE(review): confirm against `NodeConverter`.
        op = self.supported_modes_map[self._get_activation_range(node)]
        t_op.opcode_index = self.builder.op_code_index_for_op_type(op)

        self.builder.append_operators([t_op])
1 change: 1 addition & 0 deletions backends/nxp/neutron_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter, # noqa F405
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten.hardtanh.default: HardTanhConverter, # noqa F405
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
}
Expand Down
4 changes: 4 additions & 0 deletions backends/nxp/quantizer/neutron_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
AvgPoolPattern,
Conv1dPattern,
Conv2dPattern,
HardTanhInPlacePattern,
HardTanhPattern,
LinearPattern,
MaxPoolPattern,
PadPattern,
Expand Down Expand Up @@ -199,6 +201,8 @@ def __init__(self):
NeutronAtenQuantizer(PermutePattern(), static_qconfig),
NeutronAtenQuantizer(PadPattern(), static_qconfig),
NeutronAtenQuantizer(ReluPattern(), static_qconfig),
NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
NeutronAtenQuantizer(ViewPattern(), static_qconfig),
Expand Down
20 changes: 20 additions & 0 deletions backends/nxp/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,26 @@ def get_anchors(
)


class HardTanhPattern(SharedSpecPattern):
    """
    Quantizer for the HardTanh operator. A shared quantization spec is selected, as activation functions usually
    follow a computation layer.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the ATen operator overloads matched by this pattern."""
        return [torch.ops.aten.hardtanh.default]


class HardTanhInPlacePattern(SharedSpecPattern):
    """
    Quantizer for the HardTanh operator with param inplace=True. A shared quantization spec is selected, as
    activation functions usually follow a computation layer.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the ATen operator overloads matched by this pattern (the in-place variant)."""
        return [torch.ops.aten.hardtanh_.default]


class LinearPattern(QuantizationPattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.linear.default]
Expand Down
6 changes: 6 additions & 0 deletions backends/nxp/tests/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from executorch.backends.nxp.backend.ir import logger
from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
from torch.export import ExportedProgram
from torch.fx.graph import Graph


# If executed on i.MX platform, there is no tensorflow module. And typically the intention is to use the tflite python
# interpreter available in tflite_runtime
Expand Down Expand Up @@ -278,6 +280,10 @@ def convert_run_compare(
return tflite_executor, edge_program_executor


def graph_contains_any_of_ops(graph: Graph, ops: list) -> bool:
    """Return True if at least one node of `graph` has a target that is listed in `ops`.

    :param graph: Graph whose `nodes` are inspected.
    :param ops: Targets to look for.
    :return: True on the first matching node, False if no node matches (including an empty graph).
    """
    for graph_node in graph.nodes:
        if graph_node.target in ops:
            return True

    return False


class OverrideSupportedTargets:

def __init__(self, converter_class, *, new_targets):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import numpy as np
import pytest
import torch

from executorch.backends.nxp.backend.edge_program_converter import (
EdgeProgramToIRConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import (
HardTanhConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
convert_run_compare,
graph_contains_any_of_ops,
ToNCHWPreprocess,
ToNHWCPreprocess,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


@pytest.fixture(autouse=True)
def reseed_model_per_test_run():
    """Seed the torch and numpy random generators with a fixed value.

    Declared with `autouse=True`, so it does not have to be requested explicitly by the tests.
    """
    seed = 23
    # Same order as before: torch first, then numpy.
    for seed_rng in (torch.manual_seed, np.random.seed):
        seed_rng(seed)


class Relu6ConvBlock(torch.nn.Module):
    """Test model: a Conv2d (64 output channels, 4x4 kernel) followed by ReLU6."""

    def __init__(self, conv_in_channels: int = 3, inplace: bool = False):
        """
        :param conv_in_channels: Number of input channels of the convolution.
        :param inplace: Forwarded to `torch.nn.ReLU6` as its `inplace` flag.
        """
        super().__init__()

        convolution = torch.nn.Conv2d(conv_in_channels, 64, kernel_size=(4, 4))
        activation = torch.nn.ReLU6(inplace=inplace)

        self.block = torch.nn.Sequential(convolution, activation)

    def forward(self, x):
        """Run `x` through the convolution and the activation."""
        output = self.block(x)
        return output


class CustomHardTanhBlock(torch.nn.Module):
    """Test model: a Conv2d (64 output channels, 4x4 kernel) followed by Hardtanh with a configurable range."""

    def __init__(
        self,
        conv_in_channels: int = 3,
        min_act_val: float = -1.0,
        max_act_val: float = 1.0,
        inplace: bool = False,
    ):
        """
        :param conv_in_channels: Number of input channels of the convolution.
        :param min_act_val: Forwarded to `torch.nn.Hardtanh` as `min_val`.
        :param max_act_val: Forwarded to `torch.nn.Hardtanh` as `max_val`.
        :param inplace: Forwarded to `torch.nn.Hardtanh` as its `inplace` flag.
        """
        super().__init__()

        convolution = torch.nn.Conv2d(conv_in_channels, 64, kernel_size=(4, 4))
        activation = torch.nn.Hardtanh(
            min_val=min_act_val, max_val=max_act_val, inplace=inplace
        )

        self.block = torch.nn.Sequential(convolution, activation)

    def forward(self, x):
        """Run `x` through the convolution and the activation."""
        output = self.block(x)
        return output


@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)])
@pytest.mark.parametrize("inplace", [True, False])
def test_relu6_quant(mocker, input_shape: tuple[int, ...], inplace: bool):
    """Quantize and convert a Conv2d + ReLU6 model and compare the TFLite output with the edge program output.

    :param mocker: pytest-mock fixture, used to spy on `EdgeProgramToIRConverter.convert_program`.
    :param input_shape: Shape of the model input; index 1 is used as the number of convolution input channels.
    :param inplace: Whether the activation is created with `inplace=True`.
    """
    # torch.nn.ReLU6 inherits from torch.nn.Hardtanh, and hence is represented as HardTanh in ATen.
    # Testing the hardtanh originated from the torch.nn.ReLU6 op.
    model = Relu6ConvBlock(conv_in_channels=input_shape[1], inplace=inplace)

    # The spy must be installed before the pipeline runs, so that the conversion call is recorded.
    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()

    # The second element of the spied return value is not needed by this test.
    tflite_flatbuffers_model, _ = converter_spy.spy_return
    # NOTE(review): args[0] is presumably the converter instance and args[1] the program it converted -- confirm.
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    # No hardtanh node may remain in the resulting program's graph.
    ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default]
    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops)

    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
    convert_run_compare(
        exported_program,
        tfl_model=tflite_flatbuffers_model,
        tflite_input_preprocess=ToNHWCPreprocess(),
        tflite_output_preprocess=ToNCHWPreprocess(),
        input_data=input_data,
        atol=1.0,
    )


@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)])
@pytest.mark.parametrize(
    "activation_range", list(HardTanhConverter.supported_modes_map.keys())
)
@pytest.mark.parametrize("inplace", [True, False])
def test_custom_hardtanh_quant(
    mocker,
    input_shape: tuple[int, ...],
    activation_range: tuple[float, float],
    inplace: bool,
):
    """Quantize and convert a Conv2d + Hardtanh model for every range supported by `HardTanhConverter`.

    :param mocker: pytest-mock fixture, used to spy on `EdgeProgramToIRConverter.convert_program`.
    :param input_shape: Shape of the model input; index 1 is used as the number of convolution input channels.
    :param activation_range: `(min_val, max_val)` pair taken from `HardTanhConverter.supported_modes_map`.
    :param inplace: Whether the activation is created with `inplace=True`.
    """
    min_val, max_val = activation_range
    model = CustomHardTanhBlock(
        conv_in_channels=input_shape[1],
        min_act_val=min_val,
        max_act_val=max_val,
        inplace=inplace,
    )

    # The spy must be installed before the pipeline runs, so that the conversion call is recorded.
    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()

    # The second element of the spied return value is not needed by this test.
    tflite_flatbuffers_model, _ = converter_spy.spy_return
    # NOTE(review): args[0] is presumably the converter instance and args[1] the program it converted -- confirm.
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    # No hardtanh node may remain in the resulting program's graph.
    ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default]
    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops)

    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
    convert_run_compare(
        exported_program,
        tfl_model=tflite_flatbuffers_model,
        tflite_input_preprocess=ToNHWCPreprocess(),
        tflite_output_preprocess=ToNCHWPreprocess(),
        input_data=input_data,
        atol=1.0,
    )
Loading