Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import torch
import torch.fx
import torch.nn.functional as F
from executorch.backends.arm.common.debug import get_node_debug_info
from executorch.backends.arm.common.type import ensure_type
from executorch.backends.arm.quantizer import QuantizationConfig
Expand Down Expand Up @@ -477,7 +476,11 @@ def get_quant_properties( # noqa: C901
def any_or_hardtanh_min_zero(n: Node):
"""Return True for any op or hardtanh with ``min_val == 0``."""
# Check that if the node is a hardtanh, its min_val is zero
return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
return (
n.target
not in (torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default)
or n.args[1] == 0
)

if _match_pattern(
node,
Expand All @@ -487,11 +490,14 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.conv2d.default,
torch.ops.aten.conv2d.padding,
],
[torch.ops.aten.batch_norm.default, F.batch_norm],
[
torch.ops.aten.batch_norm.default,
],
[
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
torch.ops.aten.hardtanh.default,
torch.ops.aten.hardtanh_.default,
],
],
filter_fn=any_or_hardtanh_min_zero,
Expand All @@ -510,6 +516,7 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
torch.ops.aten.hardtanh.default,
torch.ops.aten.hardtanh_.default,
):
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)

Expand All @@ -521,7 +528,9 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.conv2d.default,
torch.ops.aten.conv2d.padding,
],
[torch.ops.aten.batch_norm.default, F.batch_norm],
[
torch.ops.aten.batch_norm.default,
],
],
):
if node.target in (
Expand All @@ -534,7 +543,9 @@ def any_or_hardtanh_min_zero(n: Node):
_QuantProperty(1, weight_qspec, mark_annotated=True),
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
]
elif node.target in [torch.ops.aten.batch_norm.default, F.batch_norm]:
elif node.target in [
torch.ops.aten.batch_norm.default,
]:
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif _match_pattern(
node,
Expand All @@ -549,6 +560,7 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
torch.ops.aten.hardtanh.default,
torch.ops.aten.hardtanh_.default,
],
],
any_or_hardtanh_min_zero,
Expand Down
87 changes: 70 additions & 17 deletions backends/arm/test/misc/test_bn_relu_folding_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from typing import Tuple

import torch
import torch.nn.functional as F
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
from executorch.backends.arm.tosa import TosaSpecification

from executorch.backends.xnnpack.test.tester.tester import Quantize
from torch import nn
Expand All @@ -21,51 +21,104 @@
input_t1 = Tuple[torch.Tensor] # Input x


class ConvModule(torch.nn.Module):
class Conv2dModule(torch.nn.Module):
input_shape = (1, 28, 28)
batch_size = 64
test_data: input_t1 = (torch.randn(batch_size, *input_shape),)

def __init__(self, batch_norm: bool = True) -> None:
def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 16, 3, stride=2)
self.bn = nn.BatchNorm2d(num_features=16) if batch_norm else nn.Identity()
self.relu = nn.ReLU(inplace=inplace)

def forward(self, x: torch.Tensor):
x = self.conv(x)
x = self.bn(x)
x = F.relu(x)
x = self.relu(x)

return x


class Conv1dModule(torch.nn.Module):
input_shape = (3, 10)
batch_size = 2
test_data: input_t1 = (torch.randn(batch_size, *input_shape),)

def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None:
super().__init__()
self.conv = torch.nn.Conv1d(3, 8, 5, padding=2)
self.bn = nn.BatchNorm1d(num_features=8) if batch_norm else nn.Identity()
self.relu = nn.ReLU(inplace=inplace)

def forward(self, x: torch.Tensor):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)

return x


models = {
# name : (model, is_per_channel)
"conv_bn_relu_per_channel": (ConvModule(batch_norm=True), True),
"conv_relu_per_channel": (ConvModule(batch_norm=False), True),
"conv_bn_relu_per_tensor": (ConvModule(batch_norm=True), False),
"conv_relu_per_tensor": (ConvModule(batch_norm=False), False),
"conv1d_bn_relu_per_channel": (Conv1dModule(batch_norm=True), True),
"conv1d_relu_per_channel": (Conv1dModule(batch_norm=False), True),
"conv1d_bn_relu_per_tensor": (Conv1dModule(batch_norm=True), False),
"conv1d_relu_per_tensor": (Conv1dModule(batch_norm=False), False),
"conv2d_bn_relu_per_channel": (Conv2dModule(batch_norm=True), True),
"conv2d_relu_per_channel": (Conv2dModule(batch_norm=False), True),
"conv2d_bn_relu_per_tensor": (Conv2dModule(batch_norm=True), False),
"conv2d_relu_per_tensor": (Conv2dModule(batch_norm=False), False),
"conv1d_bn_relu_inplace_per_channel": (
Conv1dModule(batch_norm=True, inplace=True),
True,
),
"conv1d_relu_inplace_per_channel": (
Conv1dModule(batch_norm=False, inplace=True),
True,
),
"conv1d_bn_relu_inplace_per_tensor": (
Conv1dModule(batch_norm=True, inplace=True),
False,
),
"conv1d_relu_inplace_per_tensor": (
Conv1dModule(batch_norm=False, inplace=True),
False,
),
"conv2d_bn_relu_inplace_per_channel": (
Conv2dModule(batch_norm=True, inplace=True),
True,
),
"conv2d_relu_inplace_per_channel": (
Conv2dModule(batch_norm=False, inplace=True),
True,
),
"conv2d_bn_relu_inplace_per_tensor": (
Conv2dModule(batch_norm=True, inplace=True),
False,
),
"conv2d_relu_inplace_per_tensor": (
Conv2dModule(batch_norm=False, inplace=True),
False,
),
}


@common.parametrize("test_data", models)
@common.parametrize(
"test_data",
models,
)
def test_qat_tosa_INT(test_data):
model, per_channel = test_data
pipeline = TosaPipelineINT[input_t1](model, model.test_data, [], [], qtol=1)
tosa_version = conftest.get_option("tosa_version")
tosa_profiles = {
"1.0": common.TosaSpecification.create_from_string("TOSA-1.0+INT"),
}
tosa_spec = tosa_profiles[tosa_version]
quantizer = TOSAQuantizer(tosa_spec)
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
pipeline.change_args(
"quantize",
Quantize(
quantizer=quantizer,
quantization_config=get_symmetric_quantization_config(
is_qat=True, is_per_channel=per_channel
),
is_qat=True,
),
)
pipeline.run()