diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 81b86992dee..6607b00e79c 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from .annotate_adaptive_avg_pool1d import AnnotateAdaptiveAvgPool1D from .annotate_quant_attrs import AnnotateQuantAttrs from .annotate_stack import AnnotateStack from .annotate_unbind import AnnotateUnbind @@ -16,6 +17,7 @@ from .decompose_einsum import DecomposeEinsum from .decompose_expm1 import DecomposeExpM1 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm +from .decompose_roll import DecomposeRoll from .decompose_silu import DecomposeSilu from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape from .fixed_linear_keep_dim import FixedLinearKeepDim @@ -38,6 +40,7 @@ __all__ = [ + AnnotateAdaptiveAvgPool1D, AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, @@ -50,6 +53,7 @@ DecomposeEinsum, DecomposeExpM1, DecomposeLinalgVectorNorm, + DecomposeRoll, DecomposeSilu, ExpandBroadcastTensorShape, FixedLinearKeepDim, diff --git a/backends/qualcomm/_passes/annotate_adaptive_avg_pool1d.py b/backends/qualcomm/_passes/annotate_adaptive_avg_pool1d.py new file mode 100644 index 00000000000..24e5104e7cb --- /dev/null +++ b/backends/qualcomm/_passes/annotate_adaptive_avg_pool1d.py @@ -0,0 +1,45 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch +from executorch.backends.qualcomm.builders.node_visitor import q_ops +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + +from .utils import get_quant_attrs + + +class AnnotateAdaptiveAvgPool1D(ExportPass): + """ + Add "quant_attrs" to graph nodes' meta from the QDQ information + generated after quantization process. + adaptive_avg_pool1d got decomposed to unsqueeze -> adaptive_avg_pool2d -> squeeze + """ + + decomp_ops = [torch.ops.aten.adaptive_avg_pool2d.default] + + def __init__(self, edge_program: torch.export.ExportedProgram): + super(AnnotateAdaptiveAvgPool1D, self).__init__() + self.edge_program = edge_program + + def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule): + partitions = get_source_partitions( + graph_module.graph, [torch.ops.aten.adaptive_avg_pool1d.default] + ) + for src_partitions in partitions.values(): + for src_partition in src_partitions: + output = src_partition.output_nodes[0] + if (list(output.users)[0].target) in q_ops: + quant_attrs = get_quant_attrs( + self.edge_program, list(output.users)[0] + ) + for n in src_partition.nodes: + n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy() + + def call(self, graph_module: torch.fx.GraphModule): + self._annotate_adaptive_avg_pool1d(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py index d9ef9cb691d..64496b71f1c 100644 --- a/backends/qualcomm/_passes/annotate_quant_attrs.py +++ b/backends/qualcomm/_passes/annotate_quant_attrs.py @@ -7,6 +7,7 @@ from typing import Any, Dict import torch +from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops from executorch.backends.qualcomm.builders.utils import get_parameter from executorch.backends.qualcomm.utils.constants import ( QCOM_DTYPE, @@ -20,7 +21,7 @@ ) from executorch.exir.pass_base import ExportPass, PassResult -from .utils import dq_ops, get_quant_attrs, q_ops +from .utils import get_quant_attrs class AnnotateQuantAttrs(ExportPass): diff --git a/backends/qualcomm/_passes/annotate_stack.py b/backends/qualcomm/_passes/annotate_stack.py index 5fbfde058b2..88ee4e41ee6 100644 --- a/backends/qualcomm/_passes/annotate_stack.py +++ b/backends/qualcomm/_passes/annotate_stack.py @@ -4,11 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch +from executorch.backends.qualcomm.builders.node_visitor import q_ops from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.passes.utils.source_matcher_utils import get_source_partitions -from .utils import get_quant_attrs, q_ops +from .utils import get_quant_attrs class AnnotateStack(ExportPass): diff --git a/backends/qualcomm/_passes/annotate_unbind.py b/backends/qualcomm/_passes/annotate_unbind.py index 426285e872b..d9141dbc4c0 100644 --- a/backends/qualcomm/_passes/annotate_unbind.py +++ b/backends/qualcomm/_passes/annotate_unbind.py @@ -4,11 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch +from executorch.backends.qualcomm.builders.node_visitor import dq_ops from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.passes.utils.source_matcher_utils import get_source_partitions -from .utils import dq_ops, get_quant_attrs +from .utils import get_quant_attrs class AnnotateUnbind(ExportPass): diff --git a/backends/qualcomm/_passes/decompose_roll.py b/backends/qualcomm/_passes/decompose_roll.py new file mode 100644 index 00000000000..e13433508f5 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_roll.py @@ -0,0 +1,93 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import copy_nn_module_stack + + +class SliceCopy(torch.nn.Module): + def __init__(self, val_shape, shifts, dims): + super().__init__() + self.val_shape = val_shape + if dims[0] is None: + self.shifts = [shifts[0] % torch.numel(torch.tensor(val_shape))] + else: + self.shifts = [shift % val_shape[dim] for shift, dim in zip(shifts, dims)] + self.dims = dims + + def forward(self, x): + if self.dims[0] is None: + y = x.flatten() + y = torch.cat((y[-self.shifts[0] :], y[: -self.shifts[0]])) + return y.view(self.val_shape) + + for shift, dim in zip(self.shifts, self.dims): + x = torch.cat( + ( + x[(slice(None),) * dim + (slice(-shift, None),)], + x[(slice(None),) * dim + (slice(0, -shift),)], + ), + dim=dim, + ) + return x + + +class DecomposeRoll(ExportPass): + """ + Decompose roll into slice and cat. + """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if "roll" in str(node.target): + input_node, shifts = node.args[0], node.args[1] + dims = node.args[2] if len(node.args) == 3 else None + + # Normalize shifts and dims to lists + shifts = shifts if isinstance(shifts, (list, tuple)) else [shifts] + dims = dims if isinstance(dims, (list, tuple)) else [dims] + + model = SliceCopy(input_node.meta["val"].shape, shifts, dims) + decomposed_module = torch.export.export( + model, (input_node.meta["val"],), strict=True + ).module() + + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {"x": input_node} + + for decomposed_node in decomposed_module.graph.nodes: + copy_nn_module_stack(node, decomposed_node) + # no need to copy existent 'output' + if decomposed_node.op == "output": + for user in node.users.copy(): + # remap + user.replace_input_with( + node, + remap[decomposed_node.args[0][0]], + ) + # no need to copy existent placeholders + elif decomposed_node.op == "placeholder": + # replace node map from string to graph node + remap[decomposed_node] = remap.pop(decomposed_node.name) + else: + remap[decomposed_node] = graph.node_copy( + decomposed_node, + arg_transform=lambda x, remap=remap: remap[x], + ) + + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py b/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py index 4fe87604fc1..f6bba3001cb 100644 --- a/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py +++ b/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py @@ -5,12 +5,11 @@ # LICENSE file in the root directory of this source tree. import torch +from executorch.backends.qualcomm.builders.node_visitor import dq_ops from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.passes import dead_code_elimination_pass -from .utils import dq_ops - class ExpandBroadcastTensorShape(ExportPass): """ diff --git a/backends/qualcomm/_passes/fold_qdq.py b/backends/qualcomm/_passes/fold_qdq.py index accf66d4c35..7a0ef10385b 100644 --- a/backends/qualcomm/_passes/fold_qdq.py +++ b/backends/qualcomm/_passes/fold_qdq.py @@ -4,14 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch +from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops from executorch.backends.qualcomm.builders.utils import is_parameter from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.passes import dead_code_elimination_pass -from .utils import dq_ops, q_ops - class FoldQDQ(ExportPass): """ diff --git a/backends/qualcomm/_passes/insert_io_qdq.py b/backends/qualcomm/_passes/insert_io_qdq.py index 668f76cd695..e5b15f2d12c 100644 --- a/backends/qualcomm/_passes/insert_io_qdq.py +++ b/backends/qualcomm/_passes/insert_io_qdq.py @@ -7,6 +7,8 @@ import torch +from executorch.backends.qualcomm.builders.node_visitor import q_ops + from executorch.backends.qualcomm.builders.utils import is_parameter from executorch.backends.qualcomm.utils.constants import ( QCOM_ENCODING, @@ -16,8 +18,6 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from .utils import q_ops - class InsertIOQDQ(ExportPass): """ diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py index 9b3a308813e..757d5cee2c4 100644 --- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py +++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py @@ -50,6 +50,7 @@ class TensorOpInfo: aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False), aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True), aten.where.Scalar: TensorOpInfo(aten.where.self, False, True), + aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False), } @@ -78,7 +79,7 @@ def _build_tensor_constant( # For dtype, in some cases, we cannot use node.args[0] as scalar dtype. # Ex: Where op args[0] can be bool, however, we probably want args[1] and args[2] to be dtype same as node.meta["val"] instead of bool type tensor = torch.tensor( - [const_val], + const_val, dtype=( node.args[0].meta["val"].dtype if not is_float_tensor(node) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 63c303eb689..58a11fecb12 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -9,6 +9,7 @@ from typing import Dict from executorch.backends.qualcomm._passes import ( + AnnotateAdaptiveAvgPool1D, AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, @@ -21,6 +22,7 @@ DecomposeEinsum, DecomposeExpM1, DecomposeLinalgVectorNorm, + DecomposeRoll, DecomposeSilu, ExpandBroadcastTensorShape, FixedLinearKeepDim, @@ -73,6 +75,7 @@ def get_capture_program_passes(): # The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default. # If a pass is activated, it will be executed by default. default_passes_and_setting = [ + (AnnotateAdaptiveAvgPool1D, True), (AnnotateQuantAttrs, True), (AnnotateStack, True), (AnnotateUnbind, True), @@ -128,11 +131,11 @@ def get_to_edge_transform_passes( dep_table: Dict = None, ): # TODO: remove this workaround when target could be correctly detected - from executorch.backends.qualcomm._passes import utils + from executorch.backends.qualcomm.builders import node_visitor from executorch.exir.dialects._ops import ops as exir_ops - utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default) - utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default) + node_visitor.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default) + node_visitor.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default) passes_job = ( passes_job if passes_job is not None else get_capture_program_passes() @@ -187,6 +190,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(ReplaceArangeArgs()) self.add_pass(DecomposeCDist()) self.add_pass(DecomposeScaledDotProductAttention()) + self.add_pass(DecomposeRoll()) self.add_pass(DecomposeSilu()) self.add_pass(DecomposeEinsum()) self.add_pass(DecomposeExpM1()) @@ -198,6 +202,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): def transform_for_export_pipeline(self, exported_program: ExportedProgram): self.add_pass(DecomposeCDist()) self.add_pass(DecomposeScaledDotProductAttention()) + self.add_pass(DecomposeRoll()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(DecomposeExpM1()) # this pass will rewrite state_dict, it needs to be accomplished before diff --git a/backends/qualcomm/_passes/recompose_rms_norm.py b/backends/qualcomm/_passes/recompose_rms_norm.py index a5db826ab28..2e2063cdf6e 100644 --- a/backends/qualcomm/_passes/recompose_rms_norm.py +++ b/backends/qualcomm/_passes/recompose_rms_norm.py @@ -4,13 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch + +from executorch.backends.qualcomm.builders.node_visitor import dq_ops from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.passes.utils.source_matcher_utils import get_source_partitions -from .utils import dq_ops - class RecomposeRmsNorm(ExportPass): """ diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 10dcbb07aac..70bb705be73 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -13,19 +13,6 @@ from torch._subclasses import FakeTensor -q_ops = { - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, -} - -dq_ops = { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, -} - - def copy_meta(meta: Dict, callback=None): copied = {} for k, v in meta.items(): @@ -73,6 +60,7 @@ def get_passes_dependency_for_capture_program(): dict: A dictionary mapping each pass to its corresponding list of dependencies. """ from executorch.backends.qualcomm._passes import ( + AnnotateAdaptiveAvgPool1D, AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, @@ -94,6 +82,7 @@ def get_passes_dependency_for_capture_program(): ) return { + AnnotateAdaptiveAvgPool1D: [RemoveRedundancy], AnnotateQuantAttrs: [ RecomposePixelUnshuffle, ConvertBmmToMatmul, diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 5e9520d4c05..d804ba6b6f0 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -11,7 +11,6 @@ import numpy as np import torch -from executorch.backends.qualcomm._passes.utils import dq_ops from executorch.backends.qualcomm.utils.constants import ( QCOM_AXIS, QCOM_AXIS_ORDER, @@ -79,6 +78,18 @@ exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, } +q_ops = { + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, +} + +dq_ops = { + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, +} + class NodeVisitor: """ diff --git a/backends/qualcomm/builders/op_prelu.py b/backends/qualcomm/builders/op_prelu.py index 69ea5e005a7..b9ef6accb8a 100644 --- a/backends/qualcomm/builders/op_prelu.py +++ b/backends/qualcomm/builders/op_prelu.py @@ -40,11 +40,12 @@ def define_node( coeff = get_parameter(coeff_node, self.edge_program) coeff_tensor = torch.zeros(input_node.meta["val"].shape, dtype=coeff.dtype) # per-channel activation - if coeff_node.meta["val"].shape[0] > 1: + coeff_node_shape = coeff_node.meta["val"].shape + if len(coeff_node_shape) and coeff_node_shape[0] > 1: for i in range(input_node.meta["val"].shape[1]): coeff_tensor = coeff_tensor.index_fill(1, torch.tensor([i]), coeff[i]) else: - coeff_tensor.fill_(coeff[0]) + coeff_tensor.fill_(coeff[0] if coeff.dim() else coeff) if axis_order := input_node.meta.get(QCOM_AXIS_ORDER, None): coeff_tensor = coeff_tensor.permute(dims=axis_order).contiguous() diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 469a801feeb..f44dae54b84 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -233,6 +233,26 @@ def annotate_lt(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) +@register_annotator([torch.ops.aten.masked_fill.Tensor]) +def annotate_masked_fill(node: Node, quantization_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + input_qspec_map = {} + for input_node in node.args: + assert isinstance(input_node, Node) + if _is_float_tensor(input_node): + input_qspec_map[input_node] = quantization_config.input_activation + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=( + quantization_config.output_activation if _is_float_tensor(node) else None + ), + _annotated=True, + ) + + @register_annotator([torch.ops.aten.mul, torch.ops.aten.mul.Tensor]) def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) @@ -462,8 +482,13 @@ def annotate_neg(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) -@register_annotator([torch.ops.aten.adaptive_avg_pool2d.default]) -def annotate_adaptive_avgpool2d( +@register_annotator( + [ + torch.ops.aten.adaptive_avg_pool1d.default, + torch.ops.aten.adaptive_avg_pool2d.default, + ] +) +def annotate_adaptive_avgpool( node: Node, quantization_config: QuantizationConfig ) -> None: annotate_single_in_single_out(node, quantization_config) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 71edb6303a1..91e59c498af 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -40,6 +40,15 @@ def forward(self, x): return torch.abs(x) +class AdaptiveAvgPool1D(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + adaptive_avg_pool = torch.nn.AdaptiveAvgPool1d(1) + return adaptive_avg_pool(x) + + class AdaptiveAvgPool2D(torch.nn.Module): def __init__(self): super().__init__() @@ -1311,6 +1320,16 @@ def forward(self, x): return self.rms(x) +class Roll(torch.nn.Module): + def __init__(self, shifts, dims=None): + super().__init__() + self.shifts = shifts + self.dims = dims + + def forward(self, x): + return torch.roll(x, shifts=self.shifts, dims=self.dims) + + class Rsqrt(torch.nn.Module): def __init__(self): super().__init__() @@ -1618,6 +1637,16 @@ def forward(self, x): ) +class MaskedFill(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, attn_mask): + return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + # Mimi Decoder has 0D tensor which QNN cannot handle. class ZeroDimTensor(torch.nn.Module): def __init__(self): diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 52fc3041c7f..3f8e46524f4 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -119,6 +119,11 @@ def test_qnn_backend_abs(self): sample_input = (torch.randn(1, 2, 3, 4),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_adaptive_avg_pool1d(self): + module = AdaptiveAvgPool1D() # noqa: F405 + sample_input = (torch.randn(1, 512, 7),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_adaptive_avg_pool2d(self): module = AdaptiveAvgPool2D() # noqa: F405 sample_input = (torch.randn(1, 512, 7, 7),) @@ -799,6 +804,18 @@ def test_qnn_backend_rms_norm(self): sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_roll(self): + modules = [ + Roll(shifts=(3, 3), dims=(1, 2)), # noqa: F405 + Roll(shifts=(70, 59), dims=(1, 2)), # noqa: F405 + Roll(shifts=3), # noqa: F405 + Roll(shifts=56 * 56 * 96 + 3), # noqa: F405 + ] + sample_input = (torch.randn([1, 56, 56, 96]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rsqrt(self): module = Rsqrt() # noqa: F405 sample_input = (torch.abs(torch.randn([3, 4])),) @@ -907,6 +924,18 @@ def test_qnn_backend_where(self): for i, module in enumerate(modules): self.lower_module_and_test_output(module, sample_inputs[i]) + def test_qnn_backend_masked_fill(self): + module = MaskedFill() # noqa: F405 + attn_mask = torch.ones((64, 49, 49), dtype=torch.float32) + + # Add some zero blocks to simulate the masking behavior + for i in range(64): + if i % 2 == 0: + attn_mask[i, 35:, 35:] = 0 + + sample_input = (attn_mask,) # noqa: F405 + self.lower_module_and_test_output(module, sample_input) + class TestQNNFloatingPointModel(TestQNN): # TODO: refactor to support different backends @@ -1161,6 +1190,12 @@ def test_qnn_backend_abs(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_adaptive_avg_pool1d(self): + module = AdaptiveAvgPool1D() # noqa: F405 + sample_input = (torch.randn(1, 512, 7),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_adaptive_avg_pool2d(self): module = AdaptiveAvgPool2D() # noqa: F405 sample_input = (torch.randn(1, 512, 7, 7),) @@ -1550,13 +1585,15 @@ def test_qnn_backend_elu(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_embedding(self): - module = Embedding() # noqa: F405 + modules = [Embedding(), Embedding()] # noqa: F405 sample_input = (torch.Tensor([[1, 2, 4, 5], [4, 3, 2, 9]]).to(torch.int32),) quant_dtype = [QuantDtype.use_8a8w, QuantDtype.use_16a4w] for i, qdtype in enumerate(quant_dtype): with self.subTest(i=i): - module = self.get_qdq_module(module, sample_input, quant_dtype=qdtype) - self.lower_module_and_test_output(module, sample_input) + modules[i] = self.get_qdq_module( + modules[i], sample_input, quant_dtype=qdtype + ) + self.lower_module_and_test_output(modules[i], sample_input) def test_qnn_backend_equal(self): test_comb = [ @@ -1973,6 +2010,19 @@ def test_qnn_backend_rms_norm(self): ) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_roll(self): + modules = [ + Roll(shifts=(3, 3), dims=(1, 2)), # noqa: F405 + Roll(shifts=(70, 59), dims=(1, 2)), # noqa: F405 + Roll(shifts=3), # noqa: F405 + Roll(shifts=56 * 56 * 96 + 3), # noqa: F405 + ] + sample_input = (torch.randn([1, 56, 56, 96]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rsqrt(self): module = Rsqrt() # noqa: F405 sample_input = (torch.abs(torch.randn([3, 4])),) @@ -2097,6 +2147,19 @@ def test_qnn_backend_where(self): module = self.get_qdq_module(module, sample_inputs[i]) self.lower_module_and_test_output(module, sample_inputs[i]) + def test_qnn_backend_masked_fill(self): + module = MaskedFill() # noqa: F405 + attn_mask = torch.ones((64, 49, 49), dtype=torch.float32) + + # Add some zero blocks to simulate the masking behavior + for i in range(64): + if i % 2 == 0: + attn_mask[i, 35:, 35:] = 0 + + sample_input = (attn_mask,) # noqa: F405 + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + class TestQNNQuantizedModel(TestQNN): # TODO: refactor to support different backends @@ -3779,6 +3842,41 @@ def test_efficientSAM(self): else: self.assertGreaterEqual(msg["MIoU"], 0.55) + def test_swin_transformer(self): + if not self.required_envs([self.image_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/swin_transformer.py", + "--dataset", + self.image_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 60) + self.assertGreaterEqual(msg["top_5"], 80) + def test_esrgan(self): if not self.required_envs(): self.skipTest("missing required envs") diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 6432d67981a..e89328904a1 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -15,7 +15,7 @@ import numpy as np import torch from executorch import exir -from executorch.backends.qualcomm._passes.utils import dq_ops +from executorch.backends.qualcomm.builders.node_visitor import dq_ops from executorch.backends.qualcomm.qnn_preprocess import QnnBackend from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset diff --git a/examples/qualcomm/oss_scripts/swin_transformer.py b/examples/qualcomm/oss_scripts/swin_transformer.py new file mode 100644 index 00000000000..11afff0d70d --- /dev/null +++ b/examples/qualcomm/oss_scripts/swin_transformer.py @@ -0,0 +1,197 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import os +from multiprocessing.connection import Client + +import numpy as np + +import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + get_imagenet_dataset, + make_output_dir, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, + topk_accuracy, +) + +from transformers import AutoModelForImageClassification +from transformers.models.swin import modeling_swin + + +# Copy from transformers/models/swin/modeling_swin.py in transformers 4.47.1 +# (QCOM) Transform 6D dim to 5D dim +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + # ====================Qualcomm Changed================================= + input_feature = input_feature.view( + batch_size, + height // window_size, + window_size, + width // window_size, + window_size * num_channels, # Merge the last two dimensions + ) + windows = input_feature.permute(0, 1, 3, 2, 4).contiguous() + windows = windows.view(-1, window_size, window_size, num_channels) + # ===================================================================== + return windows + + +# Copy from transformers/models/swin/modeling_swin.py in transformers 4.47.1 +# (QCOM) Transform 6D dim to 5D dim tests on huggingface version (4.47.1) +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + num_channels = windows.shape[-1] + # ====================Qualcomm Changed================================= + windows = windows.view( + -1, + height // window_size, + width // window_size, + window_size, + window_size * num_channels, # Merge the last two dimensions + ) + windows = windows.permute(0, 1, 3, 2, 4).contiguous() + windows = windows.view(-1, height, width, num_channels) + # ===================================================================== + return windows + + +# (QCOM) Replace the original window_partition and window_reverse functions +# in the modeling_swin module with the new ones, due to QNN SDK does not support 6D tensor. +modeling_swin.window_partition = window_partition +modeling_swin.window_reverse = window_reverse + + +def main(args): + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + if not args.compile_only and args.device is None: + raise RuntimeError( + "device serial is required if not compile only. " + "Please specify a device serial by -s/--device argument." + ) + + data_num = 100 + if args.ci: + inputs = [torch.rand(1, 3, 224, 224)] + logging.warning( + "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy." + ) + else: + inputs, targets, input_list = get_imagenet_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + image_shape=(256, 256), + crop_size=224, + ) + + module = ( + AutoModelForImageClassification.from_pretrained( + "microsoft/swin-tiny-patch4-window7-224" + ) + .eval() + .to("cpu") + ) + + pte_filename = "swin_qnn_q8" + build_executorch_binary( + module.eval(), + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, + shared_buffer=args.shared_buffer, + ) + + if args.compile_only: + return + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=inputs, input_list=input_list) + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + adb.pull(output_path=args.artifact) + + # top-k analysis + predictions = [] + for i in range(data_num): + predictions.append( + np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + ) + + k_val = [1, 5] + topk = [topk_accuracy(predictions, targets, k).item() for k in k_val] + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)})) + else: + for i, k in enumerate(k_val): + print(f"top_{k}->{topk[i]}%") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=False, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " "Default ./swin", + default="./swin", + type=str, + ) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e)