From a79c9eb5f68c9c767a1c46a84c6d955076170482 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Fri, 26 Sep 2025 10:02:49 +0200 Subject: [PATCH 1/3] Arm backend: Add support for q-dq-decomposition For mixed type models we need to be able switch between FP and INT, meaning quantize and dequantize online. As there's no quantize or dequantize operator in TOSA, we need to decompose these operators to TOSA operators. This commit introduces a pass that decomposes q-dq nodes into more primitive nodes. Signed-off-by: Oscar Andersson Change-Id: I05d52572a8d614f9f77119f031915ef8bf2a00e3 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/decompose_quant_nodes.py | 156 ++++++++++++++++++ .../test/passes/test_decompose_quant_nodes.py | 44 +++++ 3 files changed, 201 insertions(+) create mode 100644 backends/arm/_passes/decompose_quant_nodes.py create mode 100644 backends/arm/test/passes/test_decompose_quant_nodes.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 0b51c28cde8..755e847a3dc 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -62,6 +62,7 @@ from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_ne_pass import DecomposeNotEqualPass # noqa +from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa from .decompose_remainder_pass import DecomposeRemainderPass # noqa from .decompose_round_pass import DecomposeRoundPass # noqa from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa diff --git a/backends/arm/_passes/decompose_quant_nodes.py b/backends/arm/_passes/decompose_quant_nodes.py new file mode 100644 index 00000000000..3cc99e7baca --- /dev/null +++ b/backends/arm/_passes/decompose_quant_nodes.py @@ -0,0 +1,156 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast, Set, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.decompose_round_pass import DecomposeRoundPass +from executorch.backends.arm.constants import DEQUANT_PER_TENSOR_OP, QUANT_PER_TENSOR_OP +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class DecomposeQuantNodesPass(ArmPass): + """Decomposes quantization nodes into more primitive operations by rewriting the graph + using the two formulas: + + quantized value = clamp(round(fp32_value / scale) + zero point, qmin, qmax) + + fp32_value = (quantized value - zp) * scale + + For quantization nodes, the pass replaces them with: + + 1. Multiplying the input by the inverse of the scale factor. + 2. Rounding the result. + 3. Adding the zero point. + 4. Clamping the result to [qmin, qmax]. + 5. Casting to the target data type. + + For dequantization nodes, the pass replaces them with: + + 1. Casting the input to int32. + 2. Subtracting the zero point. + 3. Casting to float32. + 4. Multiplying by the scale factor. + + """ + + _passes_required_after: Set[Type[ExportPass]] = {DecomposeRoundPass} + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in ( + QUANT_PER_TENSOR_OP, + DEQUANT_PER_TENSOR_OP, + ): + continue + if node.target == DEQUANT_PER_TENSOR_OP and all( + user.target == QUANT_PER_TENSOR_OP for user in node.users + ): + continue + elif ( + node.target == QUANT_PER_TENSOR_OP + and node.all_input_nodes[0].target == DEQUANT_PER_TENSOR_OP + ): + continue + modified = True + args = node.args + input_rank = args[0].meta["val"].ndim + x, scale, zero_point, qmin, qmax, dtype = args + # Instead of dividing by scale in quantization, we multiply by 1/scale + # when quantizing. + scale = cast(float, scale) + scale = scale if node.target == DEQUANT_PER_TENSOR_OP else 1.0 / scale + with graph_module.graph.inserting_before(node): + scale_const = create_node( + graph_module.graph, + exir_ops.edge.aten.full.default, + args=((1,) * input_rank, scale), + kwargs={"dtype": torch.float32}, + ) + zp_const = create_node( + graph_module.graph, + exir_ops.edge.aten.full.default, + args=((1,) * input_rank, zero_point), + kwargs={ + "dtype": ( + torch.float32 + if node.target == QUANT_PER_TENSOR_OP + else torch.int32 + ) + }, + ) + if node.target == QUANT_PER_TENSOR_OP: + # TODO MLETORCH-1587: Decompose quantization nodes using more integer arithmetic + scaled = create_node( + graph_module.graph, + exir_ops.edge.aten.mul.Tensor, + args=(x, scale_const), + from_node=node, + ) + rounded = create_node( + graph_module.graph, + exir_ops.edge.aten.round.default, + args=(scaled,), + from_node=node, + ) + shifted = create_node( + graph_module.graph, + exir_ops.edge.aten.add.Tensor, + args=(rounded, zp_const), + from_node=node, + ) + clamped = create_node( + graph_module.graph, + exir_ops.edge.aten.clamp.default, + args=(shifted, float(qmin), float(qmax)), + from_node=node, + ) + quantized = create_node( + graph_module.graph, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + args=(clamped,), + kwargs={"dtype": dtype}, + from_node=node, + ) + output = quantized + else: + input_casted_to_zp_dtype = create_node( + graph_module.graph, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + args=(x,), + kwargs={"dtype": torch.int32}, + from_node=node, + ) + shifted = create_node( + graph_module.graph, + exir_ops.edge.aten.sub.Tensor, + args=(input_casted_to_zp_dtype, zp_const), + from_node=node, + ) + casted_to_float = create_node( + graph_module.graph, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + args=(shifted,), + kwargs={"dtype": torch.float32}, + from_node=node, + ) + dequantized = create_node( + graph_module.graph, + exir_ops.edge.aten.mul.Tensor, + args=(casted_to_float, scale_const), + from_node=node, + ) + output = dequantized + node.replace_all_uses_with(output) + graph_module.graph.erase_node(node) + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified=modified) diff --git a/backends/arm/test/passes/test_decompose_quant_nodes.py b/backends/arm/test/passes/test_decompose_quant_nodes.py new file mode 100644 index 00000000000..fe216164f86 --- /dev/null +++ b/backends/arm/test/passes/test_decompose_quant_nodes.py @@ -0,0 +1,44 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm._passes import DecomposeQuantNodesPass +from executorch.backends.arm.test.common import parametrize +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + + +class Mul(torch.nn.Module): + test_data = { + "randn": (torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)), + "large_randn": (10e10 * torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)), + } + + def forward(self, x, y): + return x * y + + +@parametrize("test_data", Mul.test_data) +def test_decompose_quant_nodes_pass(test_data: Tuple[torch.Tensor]): + module = Mul() + q_dq_ops = { + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + # Verify that DecomposeQuantNodesPass removes quantize/dequantize nodes + # and that the output is correct. + pipeline = PassPipeline( + module, + test_data, + quantize=True, + pass_list=[ + DecomposeQuantNodesPass, + ], + ops_before_pass=q_dq_ops, + ops_not_after_pass=list(q_dq_ops.keys()), + tosa_extensions=["FP"], + ) + pipeline.run() From 2b2c6ac4680d00b4a2270a360126315467bdab64 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Mon, 6 Oct 2025 14:41:29 +0200 Subject: [PATCH 2/3] Arm backend: Enable q-dq decomposition Enable q-dq decomposition in ArmPassManager. Signed-off-by: Oscar Andersson Change-Id: Idc0d502c2ef1e6081abe55637ba21caad6c847d2 --- backends/arm/_passes/arm_pass_manager.py | 4 ++- backends/arm/_passes/convert_elu_params.py | 4 ++- backends/arm/_passes/convert_minmax_pass.py | 23 ++++++++---- .../fold_qdq_with_annotated_qparams_pass.py | 35 ++++++++++++++++--- backends/arm/_passes/insert_table_ops.py | 4 +-- .../arm/test/passes/test_fold_qdq_pass.py | 2 +- backends/transforms/remove_getitem_op.py | 3 ++ 7 files changed, 58 insertions(+), 17 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 75fc529a7e1..a00dd6fe9bb 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -65,6 +65,7 @@ DecomposeMaxPool2dPass, DecomposeMeanDimPass, DecomposeNotEqualPass, + DecomposeQuantNodesPass, DecomposeRemainderPass, DecomposeRoundPass, DecomposeScaledDotProductAttentionPass, @@ -186,7 +187,7 @@ def _tosa_pipeline( ] ) - # Fold Q/DQ nodes, insert INT8/INT32 rescales. + # Fold Q/DQ nodes, insert INT8/INT32 rescales, decompose quantization nodes. self.add_passes( [ FoldAndAnnotateQParamsPass(exported_program), @@ -197,6 +198,7 @@ def _tosa_pipeline( DecomposeLinearPass(), InsertRescaleInt32Pass(), InsertControlFlowRescalesPass(), + DecomposeQuantNodesPass(), ] ) diff --git a/backends/arm/_passes/convert_elu_params.py b/backends/arm/_passes/convert_elu_params.py index 6225bf92707..737ea85a156 100644 --- a/backends/arm/_passes/convert_elu_params.py +++ b/backends/arm/_passes/convert_elu_params.py @@ -38,7 +38,9 @@ def call(self, graph_module: torch.fx.GraphModule): if not is_quantized: continue with graph.inserting_after(node): - replace_node = create_node(graph, exir_ops.edge.aten.elu.default) + replace_node = create_node( + graph, exir_ops.edge.aten.elu.default, from_node=node + ) old_args = list(node.args) alpha = old_args[1] if len(old_args) > 1 else 1.0 diff --git a/backends/arm/_passes/convert_minmax_pass.py b/backends/arm/_passes/convert_minmax_pass.py index 34fcefa20e3..66da43c57b4 100644 --- a/backends/arm/_passes/convert_minmax_pass.py +++ b/backends/arm/_passes/convert_minmax_pass.py @@ -7,7 +7,10 @@ import torch from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) from executorch.backends.arm._passes.convert_squeezes_to_view import ( ConvertSqueezesToViewPass, ) @@ -131,15 +134,21 @@ def call(self, graph_module: torch.fx.GraphModule): for dim in dims: args = (input_node, dim, True) - input_node = graph_module.graph.create_node( - "call_function", op, args, node.kwargs + input_node = create_node( + graph=graph_module.graph, + op_target=op, + args=args, + kwargs={}, + from_node=node, ) if not keepdims: - input_node = graph_module.graph.create_node( - "call_function", - squeeze_op, - (input_node, dims), + input_node = create_node( + graph=graph_module.graph, + op_target=squeeze_op, + args=(input_node, dims), + kwargs={}, + from_node=node, ) replace_node.replace_all_uses_with(input_node) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 8815a47b18c..d2e32f27660 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -1,5 +1,4 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -16,10 +15,14 @@ is_param_node, set_node_arg, ) +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir import ExportedProgram @@ -230,15 +233,37 @@ def _handle_control_flow_node(self, node: Node, graph_module: GraphModule): submodule.graph.erase_node(node_to_remove) return + @staticmethod + def is_foldable(node: Node) -> bool: + if node.op != "call_function": + return False + # Don't fold chains of quant-ops into each other. + if node.target in (*Q_OPS, *DQ_OPS): + return False + + # Always fold q-dq into constant ops. + if node.target in ( + exir_ops.edge.aten.full_like.default, + *ComputeConstantOpsAOTPass.targeted_ops, + ): + return True + + # We should not fold q-dq nodes into non-quantized nodes. + if not ( + ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {}) + and ArmAnnotationInfo( + node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY] + ).quantized + ): + return False + return True + def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 # Loop over the graph nodes and find any node in the 'targeted_ops' list. for n in graph_module.graph.nodes: n = cast(Node, n) - if n.op != "call_function": - continue - # Don't fold chains of quant-ops into each other. - if n.target in (*Q_OPS, *DQ_OPS): + if not FoldAndAnnotateQParamsPass.is_foldable(n): continue # Make sure we haven't already set qparams meta information on the node diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index ade287a0cee..27de85e5ba9 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -235,8 +235,8 @@ def call(self, graph_module: GraphModule) -> PassResult: for node in graph_module.graph.nodes: if node.op != "call_function" or node not in self.table_ops: continue - input_qparams = node.meta["input_qparams"] - output_qparams = node.meta["output_qparams"] + input_qparams = node.meta.get("input_qparams", {}) + output_qparams = node.meta.get("output_qparams", {}) if len(input_qparams) == 0 or len(output_qparams) == 0: # We only want to replace the node if it's quantized continue diff --git a/backends/arm/test/passes/test_fold_qdq_pass.py b/backends/arm/test/passes/test_fold_qdq_pass.py index dcf945d5bb4..2015ab61834 100644 --- a/backends/arm/test/passes/test_fold_qdq_pass.py +++ b/backends/arm/test/passes/test_fold_qdq_pass.py @@ -20,7 +20,7 @@ class SimpleQuantizeModel(torch.nn.Module): } def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x + torch.max((x + x), (y + y)) + return x + torch.maximum((x + x), (y + y)) @common.parametrize("test_data", SimpleQuantizeModel.test_data) diff --git a/backends/transforms/remove_getitem_op.py b/backends/transforms/remove_getitem_op.py index 733393b6d9a..bb08eb4ed25 100644 --- a/backends/transforms/remove_getitem_op.py +++ b/backends/transforms/remove_getitem_op.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -77,6 +78,8 @@ def call(self, graph_module: torch.fx.GraphModule): args=node.args, kwargs=node.kwargs, ) + new_max_wd.meta = node.meta.copy() + new_max_wd.meta["val"] = new_max_wd.meta["val"][0] getitem_node.replace_all_uses_with(new_max_wd) From bfe33da1c66b4119dc2489659b41088d3225cf12 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Mon, 24 Nov 2025 10:56:44 +0100 Subject: [PATCH 3/3] Cortex_M backend: Tag quantized nodes Tag quantized nodes with ArmAnnotationInfo to make q-dq-folding pass work as expected. Signed-off-by: Oscar Andersson Change-Id: I2e473c1a268a01a9bf31a85802b56d7d2fd9dc38 --- backends/cortex_m/quantizer/quantizer.py | 36 ++++++++++++++++-------- backends/cortex_m/test/ops/test_add.py | 16 +++++++---- backends/cortex_m/test/ops/test_conv.py | 4 +-- backends/cortex_m/test/ops/test_mul.py | 16 +++++++---- 4 files changed, 45 insertions(+), 27 deletions(-) diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py index 8bfc32049ed..b86372fa360 100644 --- a/backends/cortex_m/quantizer/quantizer.py +++ b/backends/cortex_m/quantizer/quantizer.py @@ -8,6 +8,7 @@ import torch from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager from executorch.backends.cortex_m.quantizer.operator_configs import ( @@ -19,6 +20,7 @@ ) from executorch.backends.cortex_m.quantizer.quantization_configs import ( INT8_PER_TENSOR_CONFIG, + QuantizationSpec, ) from torch._ops import OpOverload from torch.fx import GraphModule, Node @@ -31,6 +33,20 @@ from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +def mark_node_as_annotated( + node: Node, + input_qspec_map: dict[Node, Optional[QuantizationSpec]], + output_qspec: Optional[QuantizationSpec], +) -> None: + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(input_qspec_map, output_qspec) + annotation_info = ArmAnnotationInfo( + quantized=True, + ) + meta_custom = node.meta.get("custom", {}) + meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info) + node.meta["custom"] = meta_custom + + class CortexMQuantizer(ComposableQuantizer): def broadcasting_filter(self, node: Optional[Node]) -> bool: @@ -211,9 +227,7 @@ def annotate_match( if all(node not in match for node in node.users) and output_qspec is None: output_qspec = config.output_activation if config else None - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map, output_qspec - ) + mark_node_as_annotated(node, input_qspec_map, output_qspec) def annotate(self, model: GraphModule) -> None: matches = self.match_patterns(model, self.operator_config.operators) @@ -242,8 +256,8 @@ def annotate(self, model: GraphModule) -> None: is_placeholder = node.op == "placeholder" is_filtered = self.filter_fn(node) if is_placeholder and not is_filtered: - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - {}, self.quantization_config.output_activation + mark_node_as_annotated( + node, {}, self.quantization_config.output_activation ) def validate(self, model: GraphModule) -> bool: @@ -271,9 +285,7 @@ def annotate(self, model: GraphModule) -> None: if not self.filter_fn(n) } output_qspec = self.quantization_config.output_activation - output_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map, output_qspec - ) + mark_node_as_annotated(output_node, input_qspec_map, output_qspec) def validate(self, model: GraphModule) -> bool: return True @@ -378,10 +390,10 @@ def _annotate_shared_cluster(self, root_node: Node) -> None: shared_qspec = SharedQuantizationSpec(shared_root_node) for node in shared_nodes: - input_qspec_map = {n: shared_qspec for n in node.all_input_nodes} - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map, shared_qspec - ) + input_qspec_map: dict[Node, Optional[QuantizationSpec]] = { + n: shared_qspec for n in node.all_input_nodes + } + mark_node_as_annotated(node, input_qspec_map, shared_qspec) def annotate(self, model: GraphModule) -> None: for node in model.graph.nodes: diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py index 8c355fd2e39..429557ae1f4 100644 --- a/backends/cortex_m/test/ops/test_add.py +++ b/backends/cortex_m/test/ops/test_add.py @@ -139,7 +139,7 @@ class CortexMAlphaAdd(ModelAlpha): } -xfails = { +xfails_implementation = { "self_scalar": ( "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, @@ -152,13 +152,17 @@ class CortexMAlphaAdd(ModelAlpha): "Expecting kwargs for aten op IR to be empty - alpha arg not supported.", AssertionError, ), - "broadcast_1": "Broadcasting not yet supported in Cortex-M backend", - "broadcast_2": "Broadcasting not yet supported in Cortex-M backend", - "broadcast_3": "Broadcasting not yet supported in Cortex-M backend", +} +xfails_dialect = xfails_implementation | { + # Cortex-M quantizer will not quantize additions that require broadcasting + # leading to the add op not being replaced by a cortex-m specific implementation + "broadcast_1": "Broadcasting is not supported in Cortex-M backend", + "broadcast_2": "Broadcasting is not supported in Cortex-M backend", + "broadcast_3": "Broadcasting is not supported in Cortex-M backend", } -@parametrize("test_case", test_cases, xfails=xfails) +@parametrize("test_case", test_cases, xfails=xfails_dialect) def test_dialect_add(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( @@ -166,7 +170,7 @@ def test_dialect_add(test_case): ) -@parametrize("test_case", test_cases, xfails=xfails) +@parametrize("test_case", test_cases, xfails=xfails_implementation) def test_implementation_add(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_conv.py b/backends/cortex_m/test/ops/test_conv.py index 8a67d1b7de1..5630abbdab3 100644 --- a/backends/cortex_m/test/ops/test_conv.py +++ b/backends/cortex_m/test/ops/test_conv.py @@ -185,7 +185,6 @@ def forward(self, x): "conv2d_dilation": "NotImplementedError: 'slow_conv_dilated<>' not implemented for 'Int'", "conv1d": "Currently not supported.", "conv2d_nchw": "Currently not supported.", - "conv3d": "Currently not supported.", } @@ -201,7 +200,6 @@ def test_dialect_conv2d(test_case): xfails_implementation = { "conv1d": "Currently not supported.", - "conv2d_nchw": "Currently not supported.", "conv3d": "Currently not supported.", } @@ -209,4 +207,4 @@ def test_dialect_conv2d(test_case): @parametrize("test_case", test_cases, xfails=xfails_implementation) def test_implementation_conv2d(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) - tester.test_implementation(qtol=1) + tester.test_implementation(qtol=2) diff --git a/backends/cortex_m/test/ops/test_mul.py b/backends/cortex_m/test/ops/test_mul.py index 35c958ce8d4..8d32a5df92a 100644 --- a/backends/cortex_m/test/ops/test_mul.py +++ b/backends/cortex_m/test/ops/test_mul.py @@ -114,7 +114,7 @@ class CortexMTensorMul(Model): } -xfail_cases = { +xfail_cases_implementation = { "self_scalar": ( "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, @@ -123,13 +123,17 @@ class CortexMTensorMul(Model): "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, ), - "broadcast_1": "Broadcasting not yet supported in Cortex-M backend", - "broadcast_2": "Broadcasting not yet supported in Cortex-M backend", - "broadcast_3": "Broadcasting not yet supported in Cortex-M backend", +} +xfail_cases_dialect = xfail_cases_implementation | { + # Cortex-M quantizer will not quantize multiplicaitons that require broadcasting + # leading to the mul op not being replaced by a cortex-m specific implementation + "broadcast_1": "Broadcasting is not supported in Cortex-M backend", + "broadcast_2": "Broadcasting is not supported in Cortex-M backend", + "broadcast_3": "Broadcasting is not supported in Cortex-M backend", } -@parametrize("test_case", test_cases, xfails=xfail_cases) +@parametrize("test_case", test_cases, xfails=xfail_cases_dialect) def test_dialect_mul(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( @@ -139,7 +143,7 @@ def test_dialect_mul(test_case): ) -@parametrize("test_case", test_cases, xfails=xfail_cases) +@parametrize("test_case", test_cases, xfails=xfail_cases_implementation) def test_implementation_mul(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_implementation(qtol=1)