# Commit 26b9e52

winskuo-quic authored and lucylq committed
## Qualcomm AI Engine Direct - GA Static QWEN2.5 0.5B (#12054)

### Summary
Static Qwen2.5 0.5B enablement. Please use 16a8w for Qwen, as other quant configs are not yet fully supported.

Script:

```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s $DEVICE -m SM8650 --prompt "Hello, how are you?" --temperature 0 --model_mode kv --max_seq_len 128 --ptq 16a8w --decoder_model qwen2_5
```

#### Stats
SM8650
![SM8650 stats](https://github.com/user-attachments/assets/ce162c20-9025-4c1c-b794-93176e4ee677)

SM8750
![SM8750 stats](https://github.com/user-attachments/assets/25db8a97-8adf-42d4-b8b4-6cdfaf933c69)

### Test plan

```bash
python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_qwen2_5 --model SM8650 --build_folder build-android/ --executorch_root . -s $DEVICE --artifact ./qwen2_5
```

Author: @haowhsu-quic, @winskuo-quic
1 parent 7cbc571 commit 26b9e52

## File tree

14 files changed (+373, -139 lines)


### backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -101,8 +101,8 @@ class LayoutTransform(ExportPass):
     exir_ops.edge.aten.pow.Tensor_Scalar,
     exir_ops.edge.aten.prelu.default,
     exir_ops.edge.aten.repeat.default,
-    exir_ops.edge.aten.round.default,
     exir_ops.edge.aten.relu.default,
+    exir_ops.edge.aten.round.default,
     exir_ops.edge.aten.sigmoid.default,
     exir_ops.edge.aten.split_with_sizes.default,
     exir_ops.edge.aten.split_with_sizes_copy.default,
```

### backends/qualcomm/quantizer/annotators.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -275,7 +275,9 @@ def annotate_masked_fill(node: Node, quantization_config: QuantizationConfig) ->
     )


-@register_annotator([torch.ops.aten.mul, torch.ops.aten.mul.Tensor])
+@register_annotator(
+    [torch.ops.aten.mul, torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor]
+)
 def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)

@@ -1298,7 +1300,7 @@ def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
     )


-@register_annotator([torch.ops.aten.zeros.default])
+@register_annotator([torch.ops.aten.zeros.default, torch.ops.aten.zeros_like.default])
 def annotate_zeros(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]) or not _is_float_tensor(node):
         return
```
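For context, `register_annotator` follows a standard registry-decorator pattern: each listed op overload is mapped to the decorated annotation function, so adding the in-place `mul_.Tensor` and the `zeros_like` variants makes those graph nodes quantizable too. A minimal sketch of the pattern (illustrative, not the exact ExecuTorch implementation):

```python
from typing import Callable, Dict, List

# Illustrative registry: op overload -> annotation function. This mirrors the
# decorator pattern used in annotators.py, not the exact ExecuTorch code.
OP_ANNOTATOR: Dict[object, Callable] = {}


def register_annotator(ops: List[object]) -> Callable:
    def decorator(annotator: Callable) -> Callable:
        # Register every variant (e.g. aten.mul.Tensor and the in-place
        # aten.mul_.Tensor) so the quantizer finds an annotator for either
        # form when it walks the exported graph.
        for op in ops:
            OP_ANNOTATOR[op] = annotator
        return annotator

    return decorator
```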

### backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 7 additions & 4 deletions

```diff
@@ -153,7 +153,9 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
     )


-def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
+def annotate_matmul_16a8w(  # noqa: C901
+    gm: torch.fx.GraphModule, annotate_conv=True
+) -> None:
     """
     This function is specific for matmul op 16a8w.
     For k, we will tag such as the below, and

@@ -317,9 +319,10 @@ def annotate_matmul_input1(node: Node):
             # The arguments of cat op: (the past kv cache, the new kv cache)
             node = node.args[0][1]
         elif node.target == torch.ops.aten.conv2d.default:
-            annotate_conv2d(
-                node, quantization_config=quantization_config_8a4w_per_channel
-            )
+            if annotate_conv:
+                annotate_conv2d(
+                    node, quantization_config=quantization_config_8a4w_per_channel
+                )
             break
         elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
             break
```
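The new `annotate_conv` flag lets a caller keep the 16a8w matmul tagging while skipping the 8a4w per-channel annotation of conv2d nodes. A hedged usage sketch; the `partial`-based wiring is an assumption modeled on how custom annotations are typically passed in the ExecuTorch Qualcomm examples, not code from this diff:

```python
from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)

# Tag matmul ops as 16a8w but leave conv2d on the quantizer's default config
# instead of forcing 8a4w per-channel (relevant for decoders like Qwen2.5,
# where only 16a8w is fully supported per the summary above).
custom_annotations = (partial(annotate_matmul_16a8w, annotate_conv=False),)
```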

### backends/qualcomm/scripts/build.sh

Lines changed: 8 additions & 0 deletions

```diff
@@ -85,6 +85,7 @@ if [ "$BUILD_AARCH64" = true ]; then
     -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \
     -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
     -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
     -DQNN_SDK_ROOT=$QNN_SDK_ROOT \
     -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \
     -DANDROID_ABI='arm64-v8a' \
@@ -104,6 +105,9 @@ if [ "$BUILD_AARCH64" = true ]; then
     -DANDROID_ABI='arm64-v8a' \
     -DANDROID_PLATFORM=android-30 \
     -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
+    -DSUPPORT_REGEX_LOOKAHEAD=ON \
+    -DBUILD_TESTING=OFF \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
     -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
     -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
@@ -134,6 +138,7 @@ if [ "$BUILD_X86_64" = true ]; then
     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
     -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
     -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
     -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
     -S $PRJ_ROOT \
     -B $BUILD_ROOT \
@@ -157,6 +162,9 @@ if [ "$BUILD_X86_64" = true ]; then
     -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
     -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
     -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
+    -DSUPPORT_REGEX_LOOKAHEAD=ON \
+    -DBUILD_TESTING=OFF \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
     -B$EXAMPLE_ROOT

 cmake --build $EXAMPLE_ROOT -j$BUILD_JOB_NUMBER
```

### backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 61 additions & 2 deletions

```diff
@@ -3999,7 +3999,7 @@ def test_llama3_2_1b(self):
             "16a4w",
             "--temperature",
             "0",
-            "--llama_model",
+            "--decoder_model",
             "llama3_2",
             "--model_mode",
             "hybrid",
@@ -4079,7 +4079,7 @@ def test_llama_stories_110m(self):
             "16a4w",
             "--temperature",
             "0",
-            "--llama_model",
+            "--decoder_model",
             "stories110m",
             "--model_mode",
             "hybrid",
@@ -4121,6 +4121,65 @@ def test_llama_stories_110m(self):
         if not self.compile_only and not self.enable_x86_64:
             self.assertGreaterEqual(msg["inference_speed"], 220)  # Lanai

+    def test_qwen2_5(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--ptq",
+            "16a8w",
+            "--decoder_model",
+            "qwen2_5",
+            "--model_mode",
+            "hybrid",
+            "--prefill_ar_len",
+            "32",
+            "--max_seq_len",
+            "128",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        # Accuracy is bad for now. Just check user's prompt is returned.
+        golden_start_with = "My favourite condiment is "
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                model_out = msg["result"][0]
+                self.assertTrue(
+                    model_out.startswith(golden_start_with),
+                    f"Expected Output: {golden_start_with}. Actual Output: {model_out}",
+                )
+                self.assertGreaterEqual(msg["inference_speed"], 95)  # Lanai
+

 class TestExampleOssScript(TestQNN):
     def test_albert(self):
```
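The test talks to the example script over `multiprocessing.connection`: the test opens a `Listener`, and the script connects back and sends a JSON payload with the generated text and measured decode speed. A minimal sketch of the sender side this handshake assumes (names such as `report_results` are hypothetical, not from this diff):

```python
import json
from multiprocessing.connection import Client


def report_results(ip: str, port: int, outputs: list, tokens_per_sec: float) -> None:
    # Hypothetical sender side: connect back to the test's Listener and send
    # a JSON string, matching the json.loads(conn.recv()) call in test_qwen2_5.
    with Client((ip, port)) as conn:
        conn.send(json.dumps({"result": outputs, "inference_speed": tokens_per_sec}))
```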

### examples/qualcomm/CMakeLists.txt

Lines changed: 2 additions & 2 deletions

```diff
@@ -77,8 +77,8 @@ target_include_directories(

 # add tokenizers
 add_subdirectory(
-  ${EXECUTORCH_ROOT}/extension/llm/tokenizers
-  ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/tokenizers
+  ${EXECUTORCH_ROOT}/extension/llm/runner
+  ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/runner
 )

 # build qnn_executor_runner
```

### examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 9 additions & 0 deletions

```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+
 # model sharding with custom op
 set(CUSTOM_OP_SRCS_FILE
   "${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp"
@@ -63,14 +64,22 @@ target_link_libraries(
   executorch_core
   extension_data_loader
   extension_flat_tensor
+  extension_llm_runner
   extension_module
   extension_tensor
+  tokenizers
   gflags
   custom_ops
   quantized_ops_lib
   quantized_kernels
   tokenizers
 )
+
+target_include_directories(
+  qnn_llama_runner
+  PUBLIC ${EXECUTORCH_ROOT}/extension/llm/tokenizers/include
+)
+
 target_compile_options(qnn_llama_runner PUBLIC ${_common_compile_options})
 set_target_properties(
   qnn_llama_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
```

### examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -1,10 +1,11 @@
 # Summary

 ## Overview
-This file provides you the instructions to run LLAMA model with different parameters via Qualcomm HTP backend. We currently support the following models:
+This file provides you the instructions to run LLM Decoder model with different parameters via Qualcomm HTP backend. We currently support the following models:
 1. LLAMA2 Stories 110M
 2. LLAMA3.2 1B
 3. LLAMA3.2 3B
+4. QWEN2.5 0.5B

 We offer the following modes to execute the model:
```
Lines changed: 45 additions & 0 deletions (new file)

```python
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


def convert_configs(config):
    # HF config keys are different from Llama configs.
    # Convert the config keys to align with Llama.
    if hasattr(config, "hidden_size"):
        config.dim = config.hidden_size
        delattr(config, "hidden_size")

    if hasattr(config, "num_attention_heads"):
        config.n_heads = config.num_attention_heads
        delattr(config, "num_attention_heads")

    if hasattr(config, "num_key_value_heads"):
        config.n_kv_heads = config.num_key_value_heads
        delattr(config, "num_key_value_heads")

    if hasattr(config, "rms_norm_eps"):
        config.norm_eps = config.rms_norm_eps
        delattr(config, "rms_norm_eps")

    if hasattr(config, "rope_theta"):
        config.rope_freq_base = config.rope_theta
        delattr(config, "rope_theta")

    if hasattr(config, "num_hidden_layers"):
        config.n_layers = config.num_hidden_layers
        delattr(config, "num_hidden_layers")

    if hasattr(config, "intermediate_size"):
        config.hidden_dim = config.intermediate_size
        delattr(config, "intermediate_size")

    if hasattr(config, "rope_scaling"):
        config.use_scaled_rope = config.rope_scaling
        # Use default value of precompute_freq_cis
        if not hasattr(config, "rope_scale_factor"):
            config.rope_scale_factor = 4

    return config
```
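A hedged usage sketch of `convert_configs` on a Hugging Face-style Qwen2.5-0.5B config; the field values below are illustrative assumptions, not taken from this commit:

```python
from types import SimpleNamespace

# Illustrative HF-style config object; real code would load the checkpoint's
# config via transformers. Values approximate a Qwen2.5-0.5B-class model.
hf_config = SimpleNamespace(
    hidden_size=896,
    num_attention_heads=14,
    num_key_value_heads=2,
    rms_norm_eps=1e-6,
    rope_theta=1000000.0,
    num_hidden_layers=24,
    intermediate_size=4864,
)

llama_config = convert_configs(hf_config)
assert llama_config.dim == 896
assert not hasattr(llama_config, "hidden_size")  # renamed to `dim`
```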
