Skip to content

Commit 1ceb59b

Browse files
committed
Qualcomm AI Engine Direct - GA Qwen 2.5 0.5B
Summary: - Add a decoder_model_wrapper.py to ensure that the exported model can be fully delegated in Qnn Backend - Add a e2e script to run qwen 2.5 - Support spin quant R3 - Replace Qwen2Attention with QCQwen2Attention - Pre-compute freqs_cos and freqs_sin to bypass rotary embedding - Replace Qwen2RMSNorm with torch.nn,.RMSNorm - Tag quant IO to avoid insering Q/DQ for I/O - Reuse executorch llama runner, llama_main Note that accuracy currently is bad, need to investigate more.
1 parent bbe90bd commit 1ceb59b

File tree

18 files changed

+1019
-8
lines changed

18 files changed

+1019
-8
lines changed

backends/qualcomm/_passes/build_quant_io.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
3939
if QCOM_QUANTIZED_IO in n.meta:
4040
n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO])
4141

42+
spec = []
43+
for user in list(call_delegate[0].users):
44+
spec.append(self._make_spec(user.meta["val"]))
45+
call_delegate[0].meta["spec"] = tuple(spec)
46+
4247
def call(self, graph_module: torch.fx.GraphModule):
4348
self._build(graph_module)
4449
graph_module.graph.eliminate_dead_code()

backends/qualcomm/builders/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
op_ceil,
2222
op_clamp,
2323
op_conv2d,
24+
op_copy,
2425
op_cos,
2526
op_cum_sum,
2627
op_depth_to_space,
@@ -78,6 +79,7 @@
7879
op_sin,
7980
op_skip_ops,
8081
op_slice_copy,
82+
op_slice_scatter,
8183
op_softmax,
8284
op_space_to_depth,
8385
op_split_with_sizes,
@@ -114,6 +116,7 @@
114116
op_ceil,
115117
op_clamp,
116118
op_conv2d,
119+
op_copy,
117120
op_cos,
118121
op_cum_sum,
119122
op_depth_to_space,
@@ -171,6 +174,7 @@
171174
op_sin,
172175
op_skip_ops,
173176
op_slice_copy,
177+
op_slice_scatter,
174178
op_softmax,
175179
op_space_to_depth,
176180
op_split_with_sizes,

backends/qualcomm/builders/op_copy.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import torch
11+
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
12+
13+
from .node_visitor import NodeVisitor
14+
from .node_visitor_manager import register_node_visitor
15+
from .qnn_constants import OpReshape, QNN_OP_PACKAGE_NAME_QTI_AISW
16+
17+
18+
@register_node_visitor
19+
class Copy(NodeVisitor):
20+
target = ["aten.copy.default"]
21+
22+
def __init__(self, *args) -> None:
23+
super().__init__(*args)
24+
25+
def define_node(
26+
self,
27+
node: torch.fx.Node,
28+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
29+
) -> PyQnnWrapper.PyQnnOpWrapper:
30+
input_node = self.get_node(node.args[1])
31+
input_tensor = self.get_tensor(input_node, node)
32+
copy_inp_tensor_wrapper = self.define_tensor(
33+
input_node,
34+
node,
35+
input_tensor,
36+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
37+
nodes_to_wrappers,
38+
)
39+
40+
copy_input_tensors = [copy_inp_tensor_wrapper]
41+
42+
if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS):
43+
quant_attrs = quant_attrs.copy()
44+
# Because there is no output after convert_pt2e, the QCOM_QUANT_ATTRS of node is none
45+
node.meta[QCOM_QUANT_ATTRS] = quant_attrs
46+
output_tensor = self.get_tensor(node, node)
47+
output_tensor_wrapper = self.define_tensor(
48+
node,
49+
node,
50+
output_tensor,
51+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
52+
nodes_to_wrappers,
53+
)
54+
copy_output_tensors = [output_tensor_wrapper]
55+
56+
copy_op = PyQnnWrapper.PyQnnOpWrapper(
57+
node.name,
58+
QNN_OP_PACKAGE_NAME_QTI_AISW,
59+
OpReshape.op_name,
60+
)
61+
copy_op.AddInputTensors(copy_input_tensors)
62+
copy_op.AddOutputTensors(copy_output_tensors)
63+
64+
return copy_op
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from typing import cast, Dict
2+
3+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
4+
import torch
5+
6+
from executorch.exir.dialects._ops import ops as exir_ops
7+
8+
from .node_visitor import NodeVisitor
9+
from .node_visitor_manager import register_node_visitor
10+
from .qnn_constants import (
11+
OpScatterNd,
12+
QNN_OP_PACKAGE_NAME_QTI_AISW,
13+
)
14+
15+
16+
@register_node_visitor
17+
class SliceScatterVisitor(NodeVisitor):
18+
target = ["aten.slice_scatter.default"]
19+
20+
def __init__(self, *args) -> None:
21+
super().__init__(*args)
22+
23+
def define_node(
24+
self,
25+
node: torch.fx.Node,
26+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
27+
) -> PyQnnWrapper.PyQnnOpWrapper:
28+
input_node = self.get_node(node.args[0])
29+
input_tensor = self.get_tensor(input_node, node)
30+
input_tensor_wrapper = self.define_tensor(
31+
input_node,
32+
node,
33+
input_tensor,
34+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
35+
nodes_to_wrappers,
36+
)
37+
38+
value_node = self.get_node(node.args[1])
39+
value_tensor = self.get_tensor(value_node, node)
40+
value_tensor_wrapper = self.define_tensor(
41+
value_node,
42+
node,
43+
value_tensor,
44+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
45+
nodes_to_wrappers,
46+
)
47+
48+
output_tensor = self.get_tensor(node, node)
49+
output_tensor_wrapper = self.define_tensor(
50+
node,
51+
node,
52+
output_tensor,
53+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
54+
nodes_to_wrappers,
55+
)
56+
dim = cast(int, node.args[2])
57+
if dim < 0:
58+
dim = dim % len(input_tensor.shape)
59+
60+
start = 0 if node.args[3] is None else cast(int, node.args[3])
61+
if start < 0:
62+
start = start % input_tensor.shape[dim]
63+
64+
if len(node.args) > 4:
65+
end = min(cast(int, node.args[4]), input_tensor.shape[dim])
66+
if end < 0:
67+
end = end % input_tensor.shape[dim]
68+
else:
69+
end = input_tensor.shape[dim]
70+
71+
step = node.args[5] if len(node.args) > 5 else 1
72+
73+
target_index_shape = []
74+
ranges = []
75+
# Collect the index
76+
for i in range(dim+1):
77+
if i == dim:
78+
target_range = torch.tensor(range(start, end, step), dtype=torch.int32)
79+
target_index_shape.append(target_range.size(-1))
80+
ranges.append(target_range)
81+
break
82+
else:
83+
size = input_tensor.size(i)
84+
target_index_shape.append(size)
85+
ranges.append(torch.arange(size, dtype=torch.int32))
86+
# last dim means x-tuple index
87+
target_index_shape.append(dim+1)
88+
target_index_tensor = torch.cartesian_prod(*ranges).reshape(target_index_shape).contiguous()
89+
90+
91+
target_index_node = torch.fx.Node(
92+
node.graph,
93+
node.name + "_target_index",
94+
"call_function",
95+
exir_ops.edge.aten.tensor.default,
96+
(), # args
97+
{}, # kwargs
98+
)
99+
target_index_tensor_wrapper = self.define_tensor(
100+
target_index_node,
101+
node,
102+
target_index_tensor,
103+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
104+
nodes_to_wrappers,
105+
)
106+
107+
index_put_op = PyQnnWrapper.PyQnnOpWrapper(
108+
node.name,
109+
QNN_OP_PACKAGE_NAME_QTI_AISW,
110+
OpScatterNd.op_name,
111+
)
112+
index_put_op.AddInputTensors(
113+
[
114+
input_tensor_wrapper,
115+
target_index_tensor_wrapper,
116+
value_tensor_wrapper,
117+
]
118+
)
119+
index_put_op.AddOutputTensors([output_tensor_wrapper])
120+
121+
return index_put_op

backends/qualcomm/partition/common_defs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111

1212
not_supported_operator = [
1313
exir_ops.edge.aten.clone.default,
14-
exir_ops.edge.aten.slice_scatter.default,
15-
exir_ops.edge.aten.copy.default,
1614
exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
1715
]
1816

backends/qualcomm/quantizer/annotators.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,21 @@ def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None
643643
def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
644644
annotate_single_in_single_out(node, quantization_config)
645645

646+
@register_annotator([torch.ops.aten.slice_scatter.default])
647+
def annotate_slice_scatter(node: Node, quantization_config: QuantizationConfig) -> None:
648+
input = node.args[0]
649+
value = node.args[1]
650+
651+
input_qspec_map = {}
652+
input_qspec_map[input] = quantization_config.input_activation
653+
input_qspec_map[value] = SharedQuantizationSpec((input, node))
654+
655+
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
656+
input_qspec_map=input_qspec_map,
657+
output_qspec=SharedQuantizationSpec((input, node)),
658+
_annotated=True,
659+
)
660+
646661

647662
@register_annotator([torch.ops.aten.sqrt.default])
648663
def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
@@ -1028,6 +1043,7 @@ def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None:
10281043
torch.ops.aten.conv1d.default,
10291044
torch.ops.aten.conv_transpose2d.input,
10301045
torch.ops.aten.conv_transpose1d.default,
1046+
torch.ops.aten.convolution.default,
10311047
]
10321048
)
10331049
def annotate_conv(node: Node, quantization_config: QuantizationConfig) -> None:

backends/qualcomm/scripts/build.sh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,20 @@ if [ "$BUILD_AARCH64" = true ]; then
110110
-B$EXAMPLE_ROOT
111111

112112
cmake --build $EXAMPLE_ROOT -j$BUILD_JOB_NUMBER
113+
114+
LLAMA_EXAMPLE_ROOT=examples/models/llama
115+
cmake $PRJ_ROOT/$LLAMA_EXAMPLE_ROOT \
116+
-DBUILD_TESTING=OFF \
117+
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \
118+
-DCMAKE_BUILD_TYPE=$BUILD_TYPE \
119+
-DANDROID_ABI='arm64-v8a' \
120+
-DANDROID_PLATFORM=android-30 \
121+
-DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
122+
-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
123+
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
124+
-B$LLAMA_EXAMPLE_ROOT
125+
126+
cmake --build $LLAMA_EXAMPLE_ROOT -j$BUILD_JOB_NUMBER
113127
fi
114128

115129
if [ "$BUILD_X86_64" = true ]; then

backends/qualcomm/tests/models.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,17 @@ def __init__(self):
228228
def forward(self, x, y):
229229
return torch.cat((y, y, x, x), axis=2)
230230

231+
class CausalMask(torch.nn.Module):
232+
def __init__(self):
233+
super().__init__()
234+
self.register_buffer("causal_mask", torch.zeros((1,1, 1, 128)))
235+
self.mask_length = 128
236+
237+
def forward(self, padding_mask):
238+
self.causal_mask[:, :, :, :self.mask_length] = self.causal_mask[:, :, :, :self.mask_length].masked_fill(
239+
padding_mask, 1
240+
)
241+
return self.causal_mask+1
231242

232243
class CDist(torch.nn.Module):
233244
def __init__(self):
@@ -1592,6 +1603,16 @@ def forward(self, x, y):
15921603
+ self.position_ids[:, : seq_length : self.step]
15931604
)
15941605

1606+
class SliceScatter(torch.nn.Module):
1607+
def __init__(self, dim, start, end, step):
1608+
super().__init__()
1609+
self.dim = dim
1610+
self.start = start
1611+
self.end = end
1612+
self.step = step
1613+
1614+
def forward(self, x, y):
1615+
return x.slice_scatter(y, dim=self.dim, start=self.start, end=self.end, step=self.step)
15951616

15961617
class Softmax(torch.nn.Module):
15971618
def __init__(self, dim):

0 commit comments

Comments
 (0)