Qualcomm AI Engine Direct - GA model enablement (Swin Transformer) #11099

Merged 1 commit on May 28, 2025
4 changes: 4 additions & 0 deletions backends/qualcomm/_passes/__init__.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .annotate_adaptive_avg_pool1d import AnnotateAdaptiveAvgPool1D
from .annotate_quant_attrs import AnnotateQuantAttrs
from .annotate_stack import AnnotateStack
from .annotate_unbind import AnnotateUnbind
@@ -16,6 +17,7 @@
from .decompose_einsum import DecomposeEinsum
from .decompose_expm1 import DecomposeExpM1
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
from .decompose_roll import DecomposeRoll
from .decompose_silu import DecomposeSilu
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
from .fixed_linear_keep_dim import FixedLinearKeepDim
@@ -38,6 +40,7 @@


__all__ = [
AnnotateAdaptiveAvgPool1D,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
@@ -50,6 +53,7 @@
DecomposeEinsum,
DecomposeExpM1,
DecomposeLinalgVectorNorm,
DecomposeRoll,
DecomposeSilu,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
45 changes: 45 additions & 0 deletions backends/qualcomm/_passes/annotate_adaptive_avg_pool1d.py
@@ -0,0 +1,45 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.builders.node_visitor import q_ops
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

from .utils import get_quant_attrs


class AnnotateAdaptiveAvgPool1D(ExportPass):
"""
Add "quant_attrs" to graph nodes' meta from the QDQ information
generated after the quantization process.
adaptive_avg_pool1d is decomposed into unsqueeze -> adaptive_avg_pool2d -> squeeze.
"""

decomp_ops = [torch.ops.aten.adaptive_avg_pool2d.default]

def __init__(self, edge_program: torch.export.ExportedProgram):
super(AnnotateAdaptiveAvgPool1D, self).__init__()
self.edge_program = edge_program

def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
partitions = get_source_partitions(
graph_module.graph, [torch.ops.aten.adaptive_avg_pool1d.default]
)
for src_partitions in partitions.values():
for src_partition in src_partitions:
output = src_partition.output_nodes[0]
if (list(output.users)[0].target) in q_ops:
quant_attrs = get_quant_attrs(
self.edge_program, list(output.users)[0]
)
for n in src_partition.nodes:
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

def call(self, graph_module: torch.fx.GraphModule):
self._annotate_adaptive_avg_pool1d(graph_module)
graph_module.recompile()
return PassResult(graph_module, True)
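For context, a minimal sketch (not part of this diff; the module name Pool1D is illustrative) of the decomposition this pass relies on: after export, adaptive_avg_pool1d typically shows up as unsqueeze -> adaptive_avg_pool2d -> squeeze, which is why the pass matches adaptive_avg_pool2d partitions sourced from adaptive_avg_pool1d and copies the consuming quantize node's attributes onto them.

import torch

class Pool1D(torch.nn.Module):
    def forward(self, x):
        # pools over the last dimension; decomposed during export
        return torch.nn.functional.adaptive_avg_pool1d(x, 1)

ep = torch.export.export(Pool1D(), (torch.randn(1, 8, 16),), strict=True)
# Depending on the PyTorch version, the graph is expected to contain
# unsqueeze -> adaptive_avg_pool2d -> squeeze rather than a 1D pooling node.
print(ep.graph)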
3 changes: 2 additions & 1 deletion backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -7,6 +7,7 @@
from typing import Any, Dict

import torch
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import (
QCOM_DTYPE,
@@ -20,7 +21,7 @@
)
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import dq_ops, get_quant_attrs, q_ops
from .utils import get_quant_attrs


class AnnotateQuantAttrs(ExportPass):
3 changes: 2 additions & 1 deletion backends/qualcomm/_passes/annotate_stack.py
@@ -4,11 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.builders.node_visitor import q_ops
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

from .utils import get_quant_attrs, q_ops
from .utils import get_quant_attrs


class AnnotateStack(ExportPass):
3 changes: 2 additions & 1 deletion backends/qualcomm/_passes/annotate_unbind.py
@@ -4,11 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

from .utils import dq_ops, get_quant_attrs
from .utils import get_quant_attrs


class AnnotateUnbind(ExportPass):
93 changes: 93 additions & 0 deletions backends/qualcomm/_passes/decompose_roll.py
@@ -0,0 +1,93 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_nn_module_stack


class SliceCopy(torch.nn.Module):
def __init__(self, val_shape, shifts, dims):
super().__init__()
self.val_shape = val_shape
if dims[0] is None:
self.shifts = [shifts[0] % torch.numel(torch.tensor(val_shape))]
else:
self.shifts = [shift % val_shape[dim] for shift, dim in zip(shifts, dims)]
self.dims = dims

def forward(self, x):
if self.dims[0] is None:
y = x.flatten()
y = torch.cat((y[-self.shifts[0] :], y[: -self.shifts[0]]))
return y.view(self.val_shape)

for shift, dim in zip(self.shifts, self.dims):
x = torch.cat(
(
x[(slice(None),) * dim + (slice(-shift, None),)],
x[(slice(None),) * dim + (slice(0, -shift),)],
),
dim=dim,
)
return x


class DecomposeRoll(ExportPass):
"""
Decompose roll into slice and cat.
"""

def __init__(self) -> None:
super().__init__()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
for node in graph.nodes:
if "roll" in str(node.target):
input_node, shifts = node.args[0], node.args[1]
dims = node.args[2] if len(node.args) == 3 else None

# Normalize shifts and dims to lists
shifts = shifts if isinstance(shifts, (list, tuple)) else [shifts]
dims = dims if isinstance(dims, (list, tuple)) else [dims]

model = SliceCopy(input_node.meta["val"].shape, shifts, dims)
decomposed_module = torch.export.export(
model, (input_node.meta["val"],), strict=True
).module()

with graph.inserting_before(node):
# remap is used to map original node values to new node values,
# which ensures that references to nodes are correctly updated in the new graph
remap = {"x": input_node}

for decomposed_node in decomposed_module.graph.nodes:
copy_nn_module_stack(node, decomposed_node)
# no need to copy the existing 'output' node; rewire users of the original node to the decomposed result
if decomposed_node.op == "output":
for user in node.users.copy():
# remap
user.replace_input_with(
node,
remap[decomposed_node.args[0][0]],
)
# no need to copy existing placeholders
elif decomposed_node.op == "placeholder":
# re-key the map entry from the placeholder's name (string) to the placeholder node itself
remap[decomposed_node] = remap.pop(decomposed_node.name)
else:
remap[decomposed_node] = graph.node_copy(
decomposed_node,
arg_transform=lambda x, remap=remap: remap[x],
)

graph.erase_node(node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
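As a quick, illustrative sanity check (not part of the diff) that the slice-and-concatenate pattern emitted by SliceCopy reproduces torch.roll for non-zero shifts:

import torch

x = torch.arange(24.0).reshape(2, 3, 4)
shifts, dims = [1, 2], [1, 2]

expected = torch.roll(x, shifts=shifts, dims=dims)

# Same computation expressed with slice + cat, one dimension at a time,
# mirroring the loop in SliceCopy.forward.
out = x
for shift, dim in zip(shifts, dims):
    shift = shift % x.shape[dim]  # normalize, as SliceCopy.__init__ does
    out = torch.cat(
        (
            out[(slice(None),) * dim + (slice(-shift, None),)],
            out[(slice(None),) * dim + (slice(0, -shift),)],
        ),
        dim=dim,
    )

assert torch.equal(out, expected)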
backends/qualcomm/_passes/expand_broadcast_tensor_shape.py
@@ -5,12 +5,11 @@
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

from .utils import dq_ops


class ExpandBroadcastTensorShape(ExportPass):
"""
3 changes: 1 addition & 2 deletions backends/qualcomm/_passes/fold_qdq.py
@@ -4,14 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
from executorch.backends.qualcomm.builders.utils import is_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

from .utils import dq_ops, q_ops


class FoldQDQ(ExportPass):
"""
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/insert_io_qdq.py
@@ -7,6 +7,8 @@

import torch

from executorch.backends.qualcomm.builders.node_visitor import q_ops

from executorch.backends.qualcomm.builders.utils import is_parameter
from executorch.backends.qualcomm.utils.constants import (
QCOM_ENCODING,
@@ -16,8 +18,6 @@
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import q_ops


class InsertIOQDQ(ExportPass):
"""
backends/qualcomm/_passes/lift_constant_scalar_operands.py
@@ -50,6 +50,7 @@ class TensorOpInfo:
aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False),
aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
}


@@ -78,7 +79,7 @@ def _build_tensor_constant(
# For dtype, in some cases we cannot use node.args[0] as the scalar dtype.
# Ex: for the where op, args[0] can be bool, but args[1] and args[2] should usually match node.meta["val"]'s dtype rather than be bool.
tensor = torch.tensor(
[const_val],
const_val,
dtype=(
node.args[0].meta["val"].dtype
if not is_float_tensor(node)
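For reference, the switch from "[const_val]" to "const_val" above lifts the scalar as a 0-dim tensor rather than a rank-1 tensor of length one, which is what ops such as aten.masked_fill.Tensor expect for their value argument. A small illustration (not part of the diff):

import torch

print(torch.tensor(0.5).shape)    # torch.Size([])  -> 0-dim scalar tensor
print(torch.tensor([0.5]).shape)  # torch.Size([1]) -> rank-1 tensor of length one

x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])
# masked_fill.Tensor accepts a 0-dim value tensor; a rank-1 value would be rejected.
print(torch.ops.aten.masked_fill.Tensor(x, mask, torch.tensor(1.0)))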
11 changes: 8 additions & 3 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -9,6 +9,7 @@
from typing import Dict

from executorch.backends.qualcomm._passes import (
AnnotateAdaptiveAvgPool1D,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
@@ -21,6 +22,7 @@
DecomposeEinsum,
DecomposeExpM1,
DecomposeLinalgVectorNorm,
DecomposeRoll,
DecomposeSilu,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
@@ -73,6 +75,7 @@ def get_capture_program_passes():
# The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default.
# Activated passes are executed; deactivated passes are skipped unless enabled by the caller.
default_passes_and_setting = [
(AnnotateAdaptiveAvgPool1D, True),
(AnnotateQuantAttrs, True),
(AnnotateStack, True),
(AnnotateUnbind, True),
@@ -128,11 +131,11 @@ def get_to_edge_transform_passes(
dep_table: Dict = None,
):
# TODO: remove this workaround once the target can be detected correctly
from executorch.backends.qualcomm._passes import utils
from executorch.backends.qualcomm.builders import node_visitor
from executorch.exir.dialects._ops import ops as exir_ops

utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
node_visitor.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
node_visitor.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)

passes_job = (
passes_job if passes_job is not None else get_capture_program_passes()
@@ -187,6 +190,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(ReplaceArangeArgs())
self.add_pass(DecomposeCDist())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoll())
self.add_pass(DecomposeSilu())
self.add_pass(DecomposeEinsum())
self.add_pass(DecomposeExpM1())
@@ -198,6 +202,7 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram):
def transform_for_export_pipeline(self, exported_program: ExportedProgram):
self.add_pass(DecomposeCDist())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoll())
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
self.add_pass(DecomposeExpM1())
# this pass will rewrite state_dict; it needs to be accomplished before
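Relating to the default_passes_and_setting comment in get_capture_program_passes() above, a hedged sketch of how a caller might opt in to a pass that is off by default. The dict-style passes_job access and the QCOM_PASS_ACTIVATE_KEY constant are assumptions here; verify both against the backend's current constants and pass-manager code.

from executorch.backends.qualcomm._passes import ExpandBroadcastTensorShape
from executorch.backends.qualcomm._passes.qnn_pass_manager import get_capture_program_passes
from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY  # assumed constant name

# Start from the default configuration and flip one pass on.
passes_job = get_capture_program_passes()
passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
# passes_job can then be handed to the QNN capture/compile entry point.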
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/recompose_rms_norm.py
@@ -4,13 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from executorch.backends.qualcomm.builders.node_visitor import dq_ops
from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

from .utils import dq_ops


class RecomposeRmsNorm(ExportPass):
"""
15 changes: 2 additions & 13 deletions backends/qualcomm/_passes/utils.py
@@ -13,19 +13,6 @@
from torch._subclasses import FakeTensor


q_ops = {
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
}

dq_ops = {
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
}


def copy_meta(meta: Dict, callback=None):
copied = {}
for k, v in meta.items():
@@ -73,6 +60,7 @@ def get_passes_dependency_for_capture_program():
dict: A dictionary mapping each pass to its corresponding list of dependencies.
"""
from executorch.backends.qualcomm._passes import (
AnnotateAdaptiveAvgPool1D,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
@@ -94,6 +82,7 @@
)

return {
AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
AnnotateQuantAttrs: [
RecomposePixelUnshuffle,
ConvertBmmToMatmul,