diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 1d20e1db4fe..c0f28cc3d87 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -16,7 +16,6 @@ import torch import torch.fx -import torch.nn.functional as F from executorch.backends.arm.common.debug import get_node_debug_info from executorch.backends.arm.common.type import ensure_type from executorch.backends.arm.quantizer import QuantizationConfig @@ -477,7 +476,11 @@ def get_quant_properties( # noqa: C901 def any_or_hardtanh_min_zero(n: Node): """Return True for any op or hardtanh with ``min_val == 0``.""" # Check that if the node is a hardtanh, its min_val is zero - return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0 + return ( + n.target + not in (torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default) + or n.args[1] == 0 + ) if _match_pattern( node, @@ -487,11 +490,14 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.conv2d.default, torch.ops.aten.conv2d.padding, ], - [torch.ops.aten.batch_norm.default, F.batch_norm], + [ + torch.ops.aten.batch_norm.default, + ], [ torch.ops.aten.relu.default, torch.ops.aten.relu_.default, torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, ], ], filter_fn=any_or_hardtanh_min_zero, @@ -510,6 +516,7 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.relu.default, torch.ops.aten.relu_.default, torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, ): quant_properties.quant_output = _QuantProperty(0, output_act_qspec) @@ -521,7 +528,9 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.conv2d.default, torch.ops.aten.conv2d.padding, ], - [torch.ops.aten.batch_norm.default, F.batch_norm], + [ + torch.ops.aten.batch_norm.default, + ], ], ): if node.target in ( @@ -534,7 +543,9 @@ def any_or_hardtanh_min_zero(n: Node): _QuantProperty(1, weight_qspec, mark_annotated=True), _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), ] - elif node.target in [torch.ops.aten.batch_norm.default, F.batch_norm]: + elif node.target in [ + torch.ops.aten.batch_norm.default, + ]: quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif _match_pattern( node, @@ -549,6 +560,7 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.relu.default, torch.ops.aten.relu_.default, torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, ], ], any_or_hardtanh_min_zero, diff --git a/backends/arm/test/misc/test_bn_relu_folding_qat.py b/backends/arm/test/misc/test_bn_relu_folding_qat.py index c88c38e869d..f2452c348f6 100644 --- a/backends/arm/test/misc/test_bn_relu_folding_qat.py +++ b/backends/arm/test/misc/test_bn_relu_folding_qat.py @@ -6,13 +6,13 @@ from typing import Tuple import torch -import torch.nn.functional as F from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_quantization_config, TOSAQuantizer, ) -from executorch.backends.arm.test import common, conftest +from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT +from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.xnnpack.test.tester.tester import Quantize from torch import nn @@ -21,43 +21,97 @@ input_t1 = Tuple[torch.Tensor] # Input x -class ConvModule(torch.nn.Module): +class Conv2dModule(torch.nn.Module): input_shape = (1, 28, 28) batch_size = 64 test_data: input_t1 = (torch.randn(batch_size, *input_shape),) - def __init__(self, batch_norm: bool = True) -> None: + def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 16, 3, stride=2) self.bn = nn.BatchNorm2d(num_features=16) if batch_norm else nn.Identity() + self.relu = nn.ReLU(inplace=inplace) def forward(self, x: torch.Tensor): x = self.conv(x) x = self.bn(x) - x = F.relu(x) + x = self.relu(x) + + return x + + +class Conv1dModule(torch.nn.Module): + input_shape = (3, 10) + batch_size = 2 + test_data: input_t1 = (torch.randn(batch_size, *input_shape),) + + def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None: + super().__init__() + self.conv = torch.nn.Conv1d(3, 8, 5, padding=2) + self.bn = nn.BatchNorm1d(num_features=8) if batch_norm else nn.Identity() + self.relu = nn.ReLU(inplace=inplace) + + def forward(self, x: torch.Tensor): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) return x models = { # name : (model, is_per_channel) - "conv_bn_relu_per_channel": (ConvModule(batch_norm=True), True), - "conv_relu_per_channel": (ConvModule(batch_norm=False), True), - "conv_bn_relu_per_tensor": (ConvModule(batch_norm=True), False), - "conv_relu_per_tensor": (ConvModule(batch_norm=False), False), + "conv1d_bn_relu_per_channel": (Conv1dModule(batch_norm=True), True), + "conv1d_relu_per_channel": (Conv1dModule(batch_norm=False), True), + "conv1d_bn_relu_per_tensor": (Conv1dModule(batch_norm=True), False), + "conv1d_relu_per_tensor": (Conv1dModule(batch_norm=False), False), + "conv2d_bn_relu_per_channel": (Conv2dModule(batch_norm=True), True), + "conv2d_relu_per_channel": (Conv2dModule(batch_norm=False), True), + "conv2d_bn_relu_per_tensor": (Conv2dModule(batch_norm=True), False), + "conv2d_relu_per_tensor": (Conv2dModule(batch_norm=False), False), + "conv1d_bn_relu_inplace_per_channel": ( + Conv1dModule(batch_norm=True, inplace=True), + True, + ), + "conv1d_relu_inplace_per_channel": ( + Conv1dModule(batch_norm=False, inplace=True), + True, + ), + "conv1d_bn_relu_inplace_per_tensor": ( + Conv1dModule(batch_norm=True, inplace=True), + False, + ), + "conv1d_relu_inplace_per_tensor": ( + Conv1dModule(batch_norm=False, inplace=True), + False, + ), + "conv2d_bn_relu_inplace_per_channel": ( + Conv2dModule(batch_norm=True, inplace=True), + True, + ), + "conv2d_relu_inplace_per_channel": ( + Conv2dModule(batch_norm=False, inplace=True), + True, + ), + "conv2d_bn_relu_inplace_per_tensor": ( + Conv2dModule(batch_norm=True, inplace=True), + False, + ), + "conv2d_relu_inplace_per_tensor": ( + Conv2dModule(batch_norm=False, inplace=True), + False, + ), } -@common.parametrize("test_data", models) +@common.parametrize( + "test_data", + models, +) def test_qat_tosa_INT(test_data): model, per_channel = test_data pipeline = TosaPipelineINT[input_t1](model, model.test_data, [], [], qtol=1) - tosa_version = conftest.get_option("tosa_version") - tosa_profiles = { - "1.0": common.TosaSpecification.create_from_string("TOSA-1.0+INT"), - } - tosa_spec = tosa_profiles[tosa_version] - quantizer = TOSAQuantizer(tosa_spec) + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) pipeline.change_args( "quantize", Quantize( @@ -65,7 +119,6 @@ def test_qat_tosa_INT(test_data): quantization_config=get_symmetric_quantization_config( is_qat=True, is_per_channel=per_channel ), - is_qat=True, ), ) pipeline.run()