Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
from .decompose_round_pass import DecomposeRoundPass # noqa
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
DecomposeMaxPool2dPass,
DecomposeMeanDimPass,
DecomposeNotEqualPass,
DecomposeQuantNodesPass,
DecomposeRemainderPass,
DecomposeRoundPass,
DecomposeScaledDotProductAttentionPass,
Expand Down Expand Up @@ -186,7 +187,7 @@ def _tosa_pipeline(
]
)

# Fold Q/DQ nodes, insert INT8/INT32 rescales.
# Fold Q/DQ nodes, insert INT8/INT32 rescales, decompose quantization nodes.
self.add_passes(
[
FoldAndAnnotateQParamsPass(exported_program),
Expand All @@ -197,6 +198,7 @@ def _tosa_pipeline(
DecomposeLinearPass(),
InsertRescaleInt32Pass(),
InsertControlFlowRescalesPass(),
DecomposeQuantNodesPass(),
]
)

Expand Down
4 changes: 3 additions & 1 deletion backends/arm/_passes/convert_elu_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def call(self, graph_module: torch.fx.GraphModule):
if not is_quantized:
continue
with graph.inserting_after(node):
replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
replace_node = create_node(
graph, exir_ops.edge.aten.elu.default, from_node=node
)
old_args = list(node.args)

alpha = old_args[1] if len(old_args) > 1 else 1.0
Expand Down
23 changes: 16 additions & 7 deletions backends/arm/_passes/convert_minmax_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
)
from executorch.backends.arm._passes.convert_squeezes_to_view import (
ConvertSqueezesToViewPass,
)
Expand Down Expand Up @@ -131,15 +134,21 @@ def call(self, graph_module: torch.fx.GraphModule):

for dim in dims:
args = (input_node, dim, True)
input_node = graph_module.graph.create_node(
"call_function", op, args, node.kwargs
input_node = create_node(
graph=graph_module.graph,
op_target=op,
args=args,
kwargs={},
from_node=node,
)

if not keepdims:
input_node = graph_module.graph.create_node(
"call_function",
squeeze_op,
(input_node, dims),
input_node = create_node(
graph=graph_module.graph,
op_target=squeeze_op,
args=(input_node, dims),
kwargs={},
from_node=node,
)

replace_node.replace_all_uses_with(input_node)
Expand Down
156 changes: 156 additions & 0 deletions backends/arm/_passes/decompose_quant_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.decompose_round_pass import DecomposeRoundPass
from executorch.backends.arm.constants import DEQUANT_PER_TENSOR_OP, QUANT_PER_TENSOR_OP
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeQuantNodesPass(ArmPass):
    """Rewrite per-tensor quantize/dequantize nodes as primitive edge ops.

    The rewrite follows the affine quantization formulas:

        quantized = clamp(round(fp32 / scale) + zero_point, qmin, qmax)
        fp32 = (quantized - zero_point) * scale

    A quantize node is replaced by the chain:

        mul(1 / scale) -> round -> add(zero_point) -> clamp(qmin, qmax)
        -> cast(dtype)

    A dequantize node is replaced by the chain:

        cast(int32) -> sub(zero_point) -> cast(float32) -> mul(scale)

    Dequantize nodes whose users are all quantize nodes, and quantize nodes
    whose first input node is a dequantize node, are left untouched.

    """

    _passes_required_after: Set[Type[ExportPass]] = {DecomposeRoundPass}

    @staticmethod
    def _build_quantize_chain(
        graph, node, x, inv_scale_const, zp_const, qmin, qmax, dtype
    ):
        """Emit the quantize decomposition and return its last node."""
        # TODO MLETORCH-1587: Decompose quantization nodes using more integer arithmetic
        scaled = create_node(
            graph,
            exir_ops.edge.aten.mul.Tensor,
            args=(x, inv_scale_const),
            from_node=node,
        )
        rounded = create_node(
            graph,
            exir_ops.edge.aten.round.default,
            args=(scaled,),
            from_node=node,
        )
        shifted = create_node(
            graph,
            exir_ops.edge.aten.add.Tensor,
            args=(rounded, zp_const),
            from_node=node,
        )
        # The clamp bounds are passed as floats, matching the float chain above.
        clamped = create_node(
            graph,
            exir_ops.edge.aten.clamp.default,
            args=(shifted, float(qmin), float(qmax)),
            from_node=node,
        )
        return create_node(
            graph,
            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
            args=(clamped,),
            kwargs={"dtype": dtype},
            from_node=node,
        )

    @staticmethod
    def _build_dequantize_chain(graph, node, x, scale_const, zp_const):
        """Emit the dequantize decomposition and return its last node."""
        as_int32 = create_node(
            graph,
            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
            args=(x,),
            kwargs={"dtype": torch.int32},
            from_node=node,
        )
        shifted = create_node(
            graph,
            exir_ops.edge.aten.sub.Tensor,
            args=(as_int32, zp_const),
            from_node=node,
        )
        as_float = create_node(
            graph,
            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
            args=(shifted,),
            kwargs={"dtype": torch.float32},
            from_node=node,
        )
        return create_node(
            graph,
            exir_ops.edge.aten.mul.Tensor,
            args=(as_float, scale_const),
            from_node=node,
        )

    def call(self, graph_module: torch.fx.GraphModule):
        """Decompose every eligible Q/DQ node in ``graph_module``.

        Returns a ``PassResult`` whose ``modified`` flag tells whether any
        node was rewritten. When something changed, dead code is eliminated,
        the module is recompiled and the parent pass's ``call`` is run on it.
        """
        graph = graph_module.graph
        modified = False

        # Iterate over a snapshot: nodes are inserted and erased in the loop.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target not in (QUANT_PER_TENSOR_OP, DEQUANT_PER_TENSOR_OP):
                continue

            is_quant = node.target == QUANT_PER_TENSOR_OP
            is_dequant = node.target == DEQUANT_PER_TENSOR_OP

            # Leave DQ -> Q pairs alone.
            if is_dequant and all(
                user.target == QUANT_PER_TENSOR_OP for user in node.users
            ):
                continue
            if is_quant and node.all_input_nodes[0].target == DEQUANT_PER_TENSOR_OP:
                continue

            modified = True
            input_rank = node.args[0].meta["val"].ndim
            x, scale, zero_point, qmin, qmax, dtype = node.args

            # Instead of dividing by scale in quantization, we multiply by
            # 1/scale when quantizing.
            scale = cast(float, scale)
            factor = scale if is_dequant else 1.0 / scale

            # Constants are rank-matched, all-ones shaped tensors.
            const_shape = (1,) * input_rank
            zp_dtype = torch.float32 if is_quant else torch.int32

            with graph.inserting_before(node):
                scale_const = create_node(
                    graph,
                    exir_ops.edge.aten.full.default,
                    args=(const_shape, factor),
                    kwargs={"dtype": torch.float32},
                )
                zp_const = create_node(
                    graph,
                    exir_ops.edge.aten.full.default,
                    args=(const_shape, zero_point),
                    kwargs={"dtype": zp_dtype},
                )
                if is_quant:
                    output = self._build_quantize_chain(
                        graph, node, x, scale_const, zp_const, qmin, qmax, dtype
                    )
                else:
                    output = self._build_dequantize_chain(
                        graph, node, x, scale_const, zp_const
                    )
                node.replace_all_uses_with(output)
                graph.erase_node(node)

        if not modified:
            return PassResult(graph_module, modified=modified)

        graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified=modified)
35 changes: 30 additions & 5 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -16,10 +15,14 @@
is_param_node,
set_node_arg,
)
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
ComputeConstantOpsAOTPass,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass

from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.exir import ExportedProgram

Expand Down Expand Up @@ -230,15 +233,37 @@ def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
submodule.graph.erase_node(node_to_remove)
return

@staticmethod
def is_foldable(node: Node) -> bool:
    """Decide whether Q/DQ nodes may be folded into ``node``.

    Returns ``False`` for nodes that are not ``call_function`` and for nodes
    that are themselves quantize/dequantize ops. ``full_like`` and the ops
    targeted by ``ComputeConstantOpsAOTPass`` are always foldable. Any other
    node is foldable only when its ``custom`` metadata carries an Arm
    annotation whose ``quantized`` attribute is truthy.
    """
    if node.op != "call_function":
        return False

    target = node.target

    # Don't fold chains of quant-ops into each other.
    if target in (*Q_OPS, *DQ_OPS):
        return False

    # Always fold q-dq into constant ops.
    always_foldable = (
        exir_ops.edge.aten.full_like.default,
        *ComputeConstantOpsAOTPass.targeted_ops,
    )
    if target in always_foldable:
        return True

    # We should not fold q-dq nodes into non-quantized nodes.
    custom_meta = node.meta.get("custom", {})
    if ArmAnnotationInfo.CUSTOM_META_KEY not in custom_meta:
        return False
    annotation = ArmAnnotationInfo(
        node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY]
    )
    return bool(annotation.quantized)

def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901

# Loop over the graph nodes and find any node in the 'targeted_ops' list.
for n in graph_module.graph.nodes:
n = cast(Node, n)
if n.op != "call_function":
continue
# Don't fold chains of quant-ops into each other.
if n.target in (*Q_OPS, *DQ_OPS):
if not FoldAndAnnotateQParamsPass.is_foldable(n):
continue

# Make sure we haven't already set qparams meta information on the node
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
if node.op != "call_function" or node not in self.table_ops:
continue
input_qparams = node.meta["input_qparams"]
output_qparams = node.meta["output_qparams"]
input_qparams = node.meta.get("input_qparams", {})
output_qparams = node.meta.get("output_qparams", {})
if len(input_qparams) == 0 or len(output_qparams) == 0:
# We only want to replace the node if it's quantized
continue
Expand Down
44 changes: 44 additions & 0 deletions backends/arm/test/passes/test_decompose_quant_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm._passes import DecomposeQuantNodesPass
from executorch.backends.arm.test.common import parametrize
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline


class Mul(torch.nn.Module):
    """Elementwise product of two tensors; the module under test."""

    # Named input pairs fed to the parametrized test below.
    test_data = dict(
        randn=(torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)),
        large_randn=(10e10 * torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)),
    )

    def forward(self, x, y):
        """Return ``x * y``."""
        product = x * y
        return product


@parametrize("test_data", Mul.test_data)
def test_decompose_quant_nodes_pass(test_data: Tuple[torch.Tensor, torch.Tensor]):
    """Run DecomposeQuantNodesPass on a quantized ``Mul`` module.

    ``test_data`` is one ``(x, y)`` input pair from ``Mul.test_data``. The
    pipeline is told to expect three quantize and three dequantize nodes
    before the pass and none of either kind after it.
    """
    module = Mul()
    q_dq_ops = {
        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
    }
    # Verify that DecomposeQuantNodesPass removes quantize/dequantize nodes
    # and that the output is correct.
    pipeline = PassPipeline(
        module,
        test_data,
        quantize=True,
        pass_list=[
            DecomposeQuantNodesPass,
        ],
        ops_before_pass=q_dq_ops,
        ops_not_after_pass=list(q_dq_ops),
        tosa_extensions=["FP"],
    )
    pipeline.run()
2 changes: 1 addition & 1 deletion backends/arm/test/passes/test_fold_qdq_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class SimpleQuantizeModel(torch.nn.Module):
}

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + torch.max((x + x), (y + y))
return x + torch.maximum((x + x), (y + y))


@common.parametrize("test_data", SimpleQuantizeModel.test_data)
Expand Down
Loading