diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 0b51c28cde8..755e847a3dc 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -62,6 +62,7 @@ from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_ne_pass import DecomposeNotEqualPass # noqa +from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa from .decompose_remainder_pass import DecomposeRemainderPass # noqa from .decompose_round_pass import DecomposeRoundPass # noqa from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 75fc529a7e1..a00dd6fe9bb 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -65,6 +65,7 @@ DecomposeMaxPool2dPass, DecomposeMeanDimPass, DecomposeNotEqualPass, + DecomposeQuantNodesPass, DecomposeRemainderPass, DecomposeRoundPass, DecomposeScaledDotProductAttentionPass, @@ -186,7 +187,7 @@ def _tosa_pipeline( ] ) - # Fold Q/DQ nodes, insert INT8/INT32 rescales. + # Fold Q/DQ nodes, insert INT8/INT32 rescales, decompose quantization nodes. 
self.add_passes( [ FoldAndAnnotateQParamsPass(exported_program), @@ -197,6 +198,7 @@ def _tosa_pipeline( DecomposeLinearPass(), InsertRescaleInt32Pass(), InsertControlFlowRescalesPass(), + DecomposeQuantNodesPass(), ] ) diff --git a/backends/arm/_passes/convert_elu_params.py b/backends/arm/_passes/convert_elu_params.py index 6225bf92707..737ea85a156 100644 --- a/backends/arm/_passes/convert_elu_params.py +++ b/backends/arm/_passes/convert_elu_params.py @@ -38,7 +38,9 @@ def call(self, graph_module: torch.fx.GraphModule): if not is_quantized: continue with graph.inserting_after(node): - replace_node = create_node(graph, exir_ops.edge.aten.elu.default) + replace_node = create_node( + graph, exir_ops.edge.aten.elu.default, from_node=node + ) old_args = list(node.args) alpha = old_args[1] if len(old_args) > 1 else 1.0 diff --git a/backends/arm/_passes/convert_minmax_pass.py b/backends/arm/_passes/convert_minmax_pass.py index 34fcefa20e3..66da43c57b4 100644 --- a/backends/arm/_passes/convert_minmax_pass.py +++ b/backends/arm/_passes/convert_minmax_pass.py @@ -7,7 +7,10 @@ import torch from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) from executorch.backends.arm._passes.convert_squeezes_to_view import ( ConvertSqueezesToViewPass, ) @@ -131,15 +134,21 @@ def call(self, graph_module: torch.fx.GraphModule): for dim in dims: args = (input_node, dim, True) - input_node = graph_module.graph.create_node( - "call_function", op, args, node.kwargs + input_node = create_node( + graph=graph_module.graph, + op_target=op, + args=args, + kwargs={}, + from_node=node, ) if not keepdims: - input_node = graph_module.graph.create_node( - "call_function", - squeeze_op, - (input_node, dims), + input_node = create_node( + graph=graph_module.graph, + op_target=squeeze_op, + args=(input_node, dims), 
+ kwargs={}, + from_node=node, ) replace_node.replace_all_uses_with(input_node) diff --git a/backends/arm/_passes/decompose_quant_nodes.py b/backends/arm/_passes/decompose_quant_nodes.py new file mode 100644 index 00000000000..3cc99e7baca --- /dev/null +++ b/backends/arm/_passes/decompose_quant_nodes.py @@ -0,0 +1,156 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast, Set, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.decompose_round_pass import DecomposeRoundPass +from executorch.backends.arm.constants import DEQUANT_PER_TENSOR_OP, QUANT_PER_TENSOR_OP +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class DecomposeQuantNodesPass(ArmPass): + """Decomposes quantization nodes into more primitive operations by rewriting the graph + using the two formulas: + + quantized value = clamp(round(fp32_value / scale) + zero point, qmin, qmax) + + fp32_value = (quantized value - zp) * scale + + For quantization nodes, the pass replaces them with: + + 1. Multiplying the input by the inverse of the scale factor. + 2. Rounding the result. + 3. Adding the zero point. + 4. Clamping the result to [qmin, qmax]. + 5. Casting to the target data type. + + For dequantization nodes, the pass replaces them with: + + 1. Casting the input to int32. + 2. Subtracting the zero point. + 3. Casting to float32. + 4. Multiplying by the scale factor. 
+ + """ + + _passes_required_after: Set[Type[ExportPass]] = {DecomposeRoundPass} + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in ( + QUANT_PER_TENSOR_OP, + DEQUANT_PER_TENSOR_OP, + ): + continue + if node.target == DEQUANT_PER_TENSOR_OP and all( + user.target == QUANT_PER_TENSOR_OP for user in node.users + ): + continue + elif ( + node.target == QUANT_PER_TENSOR_OP + and node.all_input_nodes[0].target == DEQUANT_PER_TENSOR_OP + ): + continue + modified = True + args = node.args + input_rank = args[0].meta["val"].ndim + x, scale, zero_point, qmin, qmax, dtype = args + # Instead of dividing by scale in quantization, we multiply by 1/scale + # when quantizing. + scale = cast(float, scale) + scale = scale if node.target == DEQUANT_PER_TENSOR_OP else 1.0 / scale + with graph_module.graph.inserting_before(node): + scale_const = create_node( + graph_module.graph, + exir_ops.edge.aten.full.default, + args=((1,) * input_rank, scale), + kwargs={"dtype": torch.float32}, + ) + zp_const = create_node( + graph_module.graph, + exir_ops.edge.aten.full.default, + args=((1,) * input_rank, zero_point), + kwargs={ + "dtype": ( + torch.float32 + if node.target == QUANT_PER_TENSOR_OP + else torch.int32 + ) + }, + ) + if node.target == QUANT_PER_TENSOR_OP: + # TODO MLETORCH-1587: Decompose quantization nodes using more integer arithmetic + scaled = create_node( + graph_module.graph, + exir_ops.edge.aten.mul.Tensor, + args=(x, scale_const), + from_node=node, + ) + rounded = create_node( + graph_module.graph, + exir_ops.edge.aten.round.default, + args=(scaled,), + from_node=node, + ) + shifted = create_node( + graph_module.graph, + exir_ops.edge.aten.add.Tensor, + args=(rounded, zp_const), + from_node=node, + ) + clamped = create_node( + graph_module.graph, + exir_ops.edge.aten.clamp.default, + args=(shifted, float(qmin), float(qmax)), + from_node=node, + ) + 
quantized = create_node( + graph_module.graph, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + args=(clamped,), + kwargs={"dtype": dtype}, + from_node=node, + ) + output = quantized + else: + input_casted_to_zp_dtype = create_node( + graph_module.graph, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + args=(x,), + kwargs={"dtype": torch.int32}, + from_node=node, + ) + shifted = create_node( + graph_module.graph, + exir_ops.edge.aten.sub.Tensor, + args=(input_casted_to_zp_dtype, zp_const), + from_node=node, + ) + casted_to_float = create_node( + graph_module.graph, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + args=(shifted,), + kwargs={"dtype": torch.float32}, + from_node=node, + ) + dequantized = create_node( + graph_module.graph, + exir_ops.edge.aten.mul.Tensor, + args=(casted_to_float, scale_const), + from_node=node, + ) + output = dequantized + node.replace_all_uses_with(output) + graph_module.graph.erase_node(node) + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified=modified) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 8815a47b18c..d2e32f27660 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -1,5 +1,4 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -16,10 +15,14 @@ is_param_node, set_node_arg, ) +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir import ExportedProgram @@ -230,15 +233,37 @@ def _handle_control_flow_node(self, node: Node, graph_module: GraphModule): submodule.graph.erase_node(node_to_remove) return + @staticmethod + def is_foldable(node: Node) -> bool: + if node.op != "call_function": + return False + # Don't fold chains of quant-ops into each other. + if node.target in (*Q_OPS, *DQ_OPS): + return False + + # Always fold q-dq into constant ops. + if node.target in ( + exir_ops.edge.aten.full_like.default, + *ComputeConstantOpsAOTPass.targeted_ops, + ): + return True + + # We should not fold q-dq nodes into non-quantized nodes. + if not ( + ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {}) + and ArmAnnotationInfo( + node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY] + ).quantized + ): + return False + return True + def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 # Loop over the graph nodes and find any node in the 'targeted_ops' list. for n in graph_module.graph.nodes: n = cast(Node, n) - if n.op != "call_function": - continue - # Don't fold chains of quant-ops into each other. 
- if n.target in (*Q_OPS, *DQ_OPS): + if not FoldAndAnnotateQParamsPass.is_foldable(n): continue # Make sure we haven't already set qparams meta information on the node diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index ade287a0cee..27de85e5ba9 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -235,8 +235,8 @@ def call(self, graph_module: GraphModule) -> PassResult: for node in graph_module.graph.nodes: if node.op != "call_function" or node not in self.table_ops: continue - input_qparams = node.meta["input_qparams"] - output_qparams = node.meta["output_qparams"] + input_qparams = node.meta.get("input_qparams", {}) + output_qparams = node.meta.get("output_qparams", {}) if len(input_qparams) == 0 or len(output_qparams) == 0: # We only want to replace the node if it's quantized continue diff --git a/backends/arm/test/passes/test_decompose_quant_nodes.py b/backends/arm/test/passes/test_decompose_quant_nodes.py new file mode 100644 index 00000000000..fe216164f86 --- /dev/null +++ b/backends/arm/test/passes/test_decompose_quant_nodes.py @@ -0,0 +1,44 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Tuple + +import torch +from executorch.backends.arm._passes import DecomposeQuantNodesPass +from executorch.backends.arm.test.common import parametrize +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + + +class Mul(torch.nn.Module): + test_data = { + "randn": (torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)), + "large_randn": (10e10 * torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)), + } + + def forward(self, x, y): + return x * y + + +@parametrize("test_data", Mul.test_data) +def test_decompose_quant_nodes_pass(test_data: Tuple[torch.Tensor]): + module = Mul() + q_dq_ops = { + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + # Verify that DecomposeQuantNodesPass removes quantize/dequantize nodes + # and that the output is correct. + pipeline = PassPipeline( + module, + test_data, + quantize=True, + pass_list=[ + DecomposeQuantNodesPass, + ], + ops_before_pass=q_dq_ops, + ops_not_after_pass=list(q_dq_ops.keys()), + tosa_extensions=["FP"], + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_fold_qdq_pass.py b/backends/arm/test/passes/test_fold_qdq_pass.py index dcf945d5bb4..2015ab61834 100644 --- a/backends/arm/test/passes/test_fold_qdq_pass.py +++ b/backends/arm/test/passes/test_fold_qdq_pass.py @@ -20,7 +20,7 @@ class SimpleQuantizeModel(torch.nn.Module): } def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x + torch.max((x + x), (y + y)) + return x + torch.maximum((x + x), (y + y)) @common.parametrize("test_data", SimpleQuantizeModel.test_data) diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py index 8bfc32049ed..b86372fa360 100644 --- a/backends/cortex_m/quantizer/quantizer.py +++ b/backends/cortex_m/quantizer/quantizer.py @@ -8,6 +8,7 @@ import torch from 
executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager from executorch.backends.cortex_m.quantizer.operator_configs import ( @@ -19,6 +20,7 @@ ) from executorch.backends.cortex_m.quantizer.quantization_configs import ( INT8_PER_TENSOR_CONFIG, + QuantizationSpec, ) from torch._ops import OpOverload from torch.fx import GraphModule, Node @@ -31,6 +33,20 @@ from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +def mark_node_as_annotated( + node: Node, + input_qspec_map: dict[Node, Optional[QuantizationSpec]], + output_qspec: Optional[QuantizationSpec], +) -> None: + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(input_qspec_map, output_qspec) + annotation_info = ArmAnnotationInfo( + quantized=True, + ) + meta_custom = node.meta.get("custom", {}) + meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info) + node.meta["custom"] = meta_custom + + class CortexMQuantizer(ComposableQuantizer): def broadcasting_filter(self, node: Optional[Node]) -> bool: @@ -211,9 +227,7 @@ def annotate_match( if all(node not in match for node in node.users) and output_qspec is None: output_qspec = config.output_activation if config else None - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map, output_qspec - ) + mark_node_as_annotated(node, input_qspec_map, output_qspec) def annotate(self, model: GraphModule) -> None: matches = self.match_patterns(model, self.operator_config.operators) @@ -242,8 +256,8 @@ def annotate(self, model: GraphModule) -> None: is_placeholder = node.op == "placeholder" is_filtered = self.filter_fn(node) if is_placeholder and not is_filtered: - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - {}, 
self.quantization_config.output_activation + mark_node_as_annotated( + node, {}, self.quantization_config.output_activation ) def validate(self, model: GraphModule) -> bool: @@ -271,9 +285,7 @@ def annotate(self, model: GraphModule) -> None: if not self.filter_fn(n) } output_qspec = self.quantization_config.output_activation - output_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map, output_qspec - ) + mark_node_as_annotated(output_node, input_qspec_map, output_qspec) def validate(self, model: GraphModule) -> bool: return True @@ -378,10 +390,10 @@ def _annotate_shared_cluster(self, root_node: Node) -> None: shared_qspec = SharedQuantizationSpec(shared_root_node) for node in shared_nodes: - input_qspec_map = {n: shared_qspec for n in node.all_input_nodes} - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map, shared_qspec - ) + input_qspec_map: dict[Node, Optional[QuantizationSpec]] = { + n: shared_qspec for n in node.all_input_nodes + } + mark_node_as_annotated(node, input_qspec_map, shared_qspec) def annotate(self, model: GraphModule) -> None: for node in model.graph.nodes: diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py index 8c355fd2e39..429557ae1f4 100644 --- a/backends/cortex_m/test/ops/test_add.py +++ b/backends/cortex_m/test/ops/test_add.py @@ -139,7 +139,7 @@ class CortexMAlphaAdd(ModelAlpha): } -xfails = { +xfails_implementation = { "self_scalar": ( "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, @@ -152,13 +152,17 @@ class CortexMAlphaAdd(ModelAlpha): "Expecting kwargs for aten op IR to be empty - alpha arg not supported.", AssertionError, ), - "broadcast_1": "Broadcasting not yet supported in Cortex-M backend", - "broadcast_2": "Broadcasting not yet supported in Cortex-M backend", - "broadcast_3": "Broadcasting not yet supported in Cortex-M backend", +} +xfails_dialect = xfails_implementation | { + # Cortex-M quantizer 
will not quantize additions that require broadcasting + # leading to the add op not being replaced by a cortex-m specific implementation + "broadcast_1": "Broadcasting is not supported in Cortex-M backend", + "broadcast_2": "Broadcasting is not supported in Cortex-M backend", + "broadcast_3": "Broadcasting is not supported in Cortex-M backend", } -@parametrize("test_case", test_cases, xfails=xfails) +@parametrize("test_case", test_cases, xfails=xfails_dialect) def test_dialect_add(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( @@ -166,7 +170,7 @@ def test_dialect_add(test_case): ) -@parametrize("test_case", test_cases, xfails=xfails) +@parametrize("test_case", test_cases, xfails=xfails_implementation) def test_implementation_add(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_conv.py b/backends/cortex_m/test/ops/test_conv.py index 8a67d1b7de1..5630abbdab3 100644 --- a/backends/cortex_m/test/ops/test_conv.py +++ b/backends/cortex_m/test/ops/test_conv.py @@ -185,7 +185,6 @@ def forward(self, x): "conv2d_dilation": "NotImplementedError: 'slow_conv_dilated<>' not implemented for 'Int'", "conv1d": "Currently not supported.", "conv2d_nchw": "Currently not supported.", - "conv3d": "Currently not supported.", } @@ -201,7 +200,6 @@ def test_dialect_conv2d(test_case): xfails_implementation = { "conv1d": "Currently not supported.", - "conv2d_nchw": "Currently not supported.", "conv3d": "Currently not supported.", } @@ -209,4 +207,4 @@ def test_dialect_conv2d(test_case): @parametrize("test_case", test_cases, xfails=xfails_implementation) def test_implementation_conv2d(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) - tester.test_implementation(qtol=1) + tester.test_implementation(qtol=2) diff --git a/backends/cortex_m/test/ops/test_mul.py b/backends/cortex_m/test/ops/test_mul.py index 
35c958ce8d4..8d32a5df92a 100644 --- a/backends/cortex_m/test/ops/test_mul.py +++ b/backends/cortex_m/test/ops/test_mul.py @@ -114,7 +114,7 @@ class CortexMTensorMul(Model): } -xfail_cases = { +xfail_cases_implementation = { "self_scalar": ( "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, @@ -123,13 +123,17 @@ class CortexMTensorMul(Model): "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, ), - "broadcast_1": "Broadcasting not yet supported in Cortex-M backend", - "broadcast_2": "Broadcasting not yet supported in Cortex-M backend", - "broadcast_3": "Broadcasting not yet supported in Cortex-M backend", +} +xfail_cases_dialect = xfail_cases_implementation | { + # Cortex-M quantizer will not quantize multiplications that require broadcasting + # leading to the mul op not being replaced by a cortex-m specific implementation + "broadcast_1": "Broadcasting is not supported in Cortex-M backend", + "broadcast_2": "Broadcasting is not supported in Cortex-M backend", + "broadcast_3": "Broadcasting is not supported in Cortex-M backend", } -@parametrize("test_case", test_cases, xfails=xfail_cases) +@parametrize("test_case", test_cases, xfails=xfail_cases_dialect) def test_dialect_mul(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( @@ -139,7 +143,7 @@ def test_dialect_mul(test_case): ) -@parametrize("test_case", test_cases, xfails=xfail_cases) +@parametrize("test_case", test_cases, xfails=xfail_cases_implementation) def test_implementation_mul(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_implementation(qtol=1) diff --git a/backends/transforms/remove_getitem_op.py b/backends/transforms/remove_getitem_op.py index 733393b6d9a..bb08eb4ed25 100644 --- a/backends/transforms/remove_getitem_op.py +++ b/backends/transforms/remove_getitem_op.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. 
and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -77,6 +78,8 @@ def call(self, graph_module: torch.fx.GraphModule): args=node.args, kwargs=node.kwargs, ) + new_max_wd.meta = node.meta.copy() + new_max_wd.meta["val"] = new_max_wd.meta["val"][0] getitem_node.replace_all_uses_with(new_max_wd)