diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py
index 488703db120..18dd746b41b 100644
--- a/backends/nxp/backend/edge_program_converter.py
+++ b/backends/nxp/backend/edge_program_converter.py
@@ -31,6 +31,7 @@
     exir_ops.edge.aten.mm.default: MMConverter,  # noqa F405
     exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter,  # noqa F405
     exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
+    exir_ops.edge.aten.hardtanh.default: HardTanhConverter,  # noqa F405
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,  # noqa F405
 }
diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py
index 7ed81272091..e9058f6f4ba 100755
--- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py
+++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py
@@ -10,6 +10,9 @@
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.convolution_converter import (
     ConvolutionConverter,
 )
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import (
+    HardTanhConverter,
+)
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.max_pool_2d_converter import (
     MaxPool2dConverter,
 )
@@ -48,4 +51,5 @@
     "ReLUConverter",
     "MaxPool2dConverter",
     "AvgPool2dConverter",
+    "HardTanhConverter",
 ]
diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py
new file mode 100644
index 00000000000..53f493f4ed9
--- /dev/null
+++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2025 NXP
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    NodeConverter,
+    Target,
+)
+from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
+    BuiltinOperator,
+)
+from torch.fx import Node
+from torch.nn import Parameter
+
+
+class HardTanhConverter(NodeConverter):
+    supported_targets = [Target.RT700]
+
+    # Maps possible input parameters of HardTanh to equivalent ReLU-based operators supported by TFLite.
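+    # Keys are the (min_value, max_value) arguments of the 'aten.hardtanh' node; any other range is
+    # rejected by '_is_supported_in_IR' below, so such nodes are not converted by this backend.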
+    supported_modes_map = {
+        (0.0, 6.0): BuiltinOperator.RELU6,
+        (-1.0, 1.0): BuiltinOperator.RELU_N1_TO_1,
+        (0.0, 1.0): BuiltinOperator.RELU_0_TO_1,
+        (0.0, float("inf")): BuiltinOperator.RELU,
+    }
+
+    @staticmethod
+    def _is_supported_in_IR(
+        node: Node, parameters_mapping: dict[str, Parameter]
+    ) -> bool:
+        _, min_value, max_value = node.args
+        return (min_value, max_value) in HardTanhConverter.supported_modes_map.keys()
+
+    def convert(self, node: Node):
+        """Convert 'aten::hardtanh' to its supported ReLU equivalent."""
+        self.assert_convertible(node)
+
+        t_op = self._create_tflite_op_with_io_tensors(node)
+
+        _, min_value, max_value = node.args
+
+        op = self.supported_modes_map[(min_value, max_value)]
+        t_op.opcode_index = self.builder.op_code_index_for_op_type(op)
+
+        self.builder.append_operators([t_op])
diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py
index 44863a6344e..1f31ae6c82e 100644
--- a/backends/nxp/neutron_partitioner.py
+++ b/backends/nxp/neutron_partitioner.py
@@ -195,6 +195,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
     exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter,  # noqa F405
     exir_ops.edge.aten.mm.default: MMConverter,  # noqa F405
     exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
+    exir_ops.edge.aten.hardtanh.default: HardTanhConverter,  # noqa F405
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,  # noqa F405
 }
diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py
index f467b66c9d5..c8ebb8a1966 100644
--- a/backends/nxp/quantizer/neutron_quantizer.py
+++ b/backends/nxp/quantizer/neutron_quantizer.py
@@ -16,6 +16,8 @@
     AvgPoolPattern,
     Conv1dPattern,
     Conv2dPattern,
+    HardTanhInPlacePattern,
+    HardTanhPattern,
     LinearPattern,
     MaxPoolPattern,
     PadPattern,
@@ -199,6 +201,8 @@ def __init__(self):
             NeutronAtenQuantizer(PermutePattern(), static_qconfig),
             NeutronAtenQuantizer(PadPattern(), static_qconfig),
             NeutronAtenQuantizer(ReluPattern(), static_qconfig),
+            NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
+            NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
             NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
             NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
             NeutronAtenQuantizer(ViewPattern(), static_qconfig),
diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py
index 89252f0e75d..331516dc7f7 100644
--- a/backends/nxp/quantizer/patterns.py
+++ b/backends/nxp/quantizer/patterns.py
@@ -216,6 +216,26 @@ def get_anchors(
     )
 
 
+class HardTanhPattern(SharedSpecPattern):
+    """
+    Quantizer for the HardTanh operator. The shared quantization spec is used, as activation functions
+    usually follow a computation layer.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.hardtanh.default]
+
+
+class HardTanhInPlacePattern(SharedSpecPattern):
+    """
+    Quantizer for the HardTanh operator with inplace=True. The shared quantization spec is used, as
+    activation functions usually follow a computation layer.
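+    This pattern matches 'torch.ops.aten.hardtanh_.default', the in-place variant that appears in the
+    ATen graph when the module is constructed with inplace=True.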
+ """ + + def partition_types(self): + return [torch.ops.aten.hardtanh_.default] + + class LinearPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: return [torch.ops.aten.linear.default] diff --git a/backends/nxp/tests/executors.py b/backends/nxp/tests/executors.py index 2c9fdf69f5a..38aa0ae460c 100644 --- a/backends/nxp/tests/executors.py +++ b/backends/nxp/tests/executors.py @@ -15,6 +15,8 @@ from executorch.backends.nxp.backend.ir import logger from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from torch.export import ExportedProgram +from torch.fx.graph import Graph + # If executed on i.MX platform, there is no tensorflow module. And typically the intention is to use the tflite python # interpreter available in tflite_runtime @@ -278,6 +280,10 @@ def convert_run_compare( return tflite_executor, edge_program_executor +def graph_contains_any_of_ops(graph: Graph, ops: list) -> bool: + return any(node.target in ops for node in graph.nodes) + + class OverrideSupportedTargets: def __init__(self, converter_class, *, new_targets): diff --git a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py new file mode 100644 index 00000000000..f90118f4bed --- /dev/null +++ b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py @@ -0,0 +1,126 @@ +import numpy as np +import pytest +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import ( + HardTanhConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + graph_contains_any_of_ops, + ToNCHWPreprocess, + ToNHWCPreprocess, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import ExportedProgram + + +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(23) + np.random.seed(23) + + +class Relu6ConvBlock(torch.nn.Module): + def __init__(self, conv_in_channels: int = 3, inplace: bool = False): + super().__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=conv_in_channels, out_channels=64, kernel_size=(4, 4) + ), + torch.nn.ReLU6(inplace=inplace), + ) + + def forward(self, x): + return self.block(x) + + +class CustomHardTanhBlock(torch.nn.Module): + def __init__( + self, + conv_in_channels: int = 3, + min_act_val: float = -1.0, + max_act_val: float = 1.0, + inplace: bool = False, + ): + super().__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=conv_in_channels, out_channels=64, kernel_size=(4, 4) + ), + torch.nn.Hardtanh( + min_val=min_act_val, max_val=max_act_val, inplace=inplace + ), + ) + + def forward(self, x): + return self.block(x) + + +@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)]) +@pytest.mark.parametrize("inplace", [True, False]) +def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool): + # The torch.nn.Relu6 inherits from torch.nn.Hardtanh, and hence represented as HardTanh in ATen. + # Testing the hardtanh originated from torch.nn.Relu6 op. 
+    model = Relu6ConvBlock(conv_in_channels=input_shape[1], inplace=inplace)
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()
+
+    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+    exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
+    ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default]
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops)
+
+    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
+    convert_run_compare(
+        exported_program,
+        tfl_model=tflite_flatbuffers_model,
+        tflite_input_preprocess=ToNHWCPreprocess(),
+        tflite_output_preprocess=ToNCHWPreprocess(),
+        input_data=input_data,
+        atol=1.0,
+    )
+
+
+@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)])
+@pytest.mark.parametrize(
+    "activation_range", list(HardTanhConverter.supported_modes_map.keys())
+)
+@pytest.mark.parametrize("inplace", [True, False])
+def test_custom_hardtanh_quant(
+    mocker, input_shape: tuple[int], activation_range: tuple[int, int], inplace: bool
+):
+    min_val, max_val = activation_range
+    model = CustomHardTanhBlock(
+        conv_in_channels=input_shape[1],
+        min_act_val=min_val,
+        max_act_val=max_val,
+        inplace=inplace,
+    )
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()
+
+    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+    exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
+    ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default]
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops)
+
+    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
+    convert_run_compare(
+        exported_program,
+        tfl_model=tflite_flatbuffers_model,
+        tflite_input_preprocess=ToNHWCPreprocess(),
+        tflite_output_preprocess=ToNCHWPreprocess(),
+        input_data=input_data,
+        atol=1.0,
+    )