
Commit 381bf9b

- Add LLM manager
- Support recomposing RMS norm via pattern matching
- Leverage AttentionMaskInterface and AttentionInterface without touching the model structure
- Add eval script to evaluate perplexity (ppl) on device
1 parent 1ceb59b commit 381bf9b


14 files changed, +961 -538 lines changed


backends/qualcomm/_passes/convert_conv1d_to_conv2d.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 padding = [0] + node.args[4] if num_args > 4 else [0, 0]
                 if node.target == torch.ops.aten.conv1d.default:
                     dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
-                    groups = node.args[6] if num_args > 5 else 1
+                    groups = node.args[6] if num_args > 6 else 1
                     conv_args = (
                         qdq_node_after_unsqueeze,
                         node.args[1],
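
For reference, torch.ops.aten.conv1d.default takes its positional args as (input, weight, bias, stride, padding, dilation, groups), so groups sits at index 6 and is only present when more than six args were captured; the old num_args > 5 guard could index past the end of node.args. A minimal sketch of the corrected check (illustrative only, not part of the commit):

    # Assumed aten.conv1d arg order: (input, weight, bias, stride, padding, dilation, groups)
    def _get_groups(args):
        # args[6] exists only when at least seven positional args were recorded
        return args[6] if len(args) > 6 else 1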

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 1 deletion
@@ -90,7 +90,7 @@ def get_capture_program_passes():
         (I64toI32, True),
         (LayoutTransform, True),
         (RecomposePixelUnshuffle, True),
-        (RecomposeRmsNorm, False),
+        (RecomposeRmsNorm, True),
         (Remove0DTensor, True),
         (RemoveRedundancy, True),
         (TagQuantIO, False),
@@ -188,6 +188,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(RemoveRedundancy(quantization_capture=True))
         self.add_pass(ReduceDynamicRange())
         self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
+        self.add_pass(RecomposeRmsNorm(quantization_capture=True))
         self.add_pass(ReplaceArangeArgs())
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeScaledDotProductAttention())
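
Because RecomposeRmsNorm is now on by default in get_capture_program_passes(), callers no longer need to enable it explicitly; a sketch of the caller side (this mirrors the export_llama_lib.py change later in this commit):

    passes_job = get_capture_program_passes()
    # No longer needed, the pass is active by default:
    # passes_job[RecomposeRmsNorm][QCOM_PASS_ACTIVATE_KEY] = True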

backends/qualcomm/_passes/recompose_rms_norm.py

Lines changed: 58 additions & 47 deletions
@@ -3,84 +3,95 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from executorch.backends.qualcomm._passes.utils import find_patterns
 import torch

-from executorch.backends.qualcomm.builders.node_visitor import dq_ops
-from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

+def _is_node(node): return isinstance(node, torch.fx.Node)
+def _is_call(node): return _is_node(node) and node.op == 'call_function'
+def _is_placeholder(node): return _is_node(node) and node.op == 'placeholder'
+def _is_get_attr(node): return _is_node(node) and node.op == 'get_attr'
+def _is_add(node): return _is_call(node) and node.target in [exir_ops.edge.aten.add.Tensor, torch.ops.aten.add.Tensor]
+def _is_mean(node): return _is_call(node) and node.target in [exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim]
+def _is_mul(node): return _is_call(node) and node.target in [exir_ops.edge.aten.mul.Tensor, torch.ops.aten.mul.Tensor]
+def _is_pow(node): return _is_call(node) and node.target in [exir_ops.edge.aten.pow.Tensor_Tensor, torch.ops.aten.pow.Tensor_Scalar]
+def _is_rsqrt(node): return _is_call(node) and node.target in [exir_ops.edge.aten.rsqrt.default, torch.ops.aten.rsqrt.default]

 class RecomposeRmsNorm(ExportPass):
     """
     Merge decomposed operators back to one super node.
-    TODO: After replacing export_to_edge with to_edge_transform_and_lowering
-    in examples/models/llama/export_llama_lib.py, this pass can be removed
     """

-    def __init__(self, edge_program: torch.export.ExportedProgram):
+    def __init__(self, quantization_capture=False):
         super(RecomposeRmsNorm, self).__init__()
-        self.edge_program = edge_program
-
-    def _get_eps_node(self, nodes):
-        # eps: one of inputs of add node
-        add_node = [n for n in nodes if hasattr(n, "name") and "add" in n.name][0]
-        for a in add_node.args:
-            if isinstance(a, float) or a.op != "call_function":
-                return a
-
-    def _get_gamma_node(self, output_node):
-        # gamma: one of inputs of output node
-        for a in output_node.args:
-            if a.op != "call_function" or a.target in dq_ops:
-                return a
+        self.rms_norm_target = exir_ops.edge.aten.rms_norm.default
+        self.skip_targets = [exir_ops.edge.aten.to.dtype,]
+        if quantization_capture:
+            self.rms_norm_target = torch.ops.aten.rms_norm.default
+            self.skip_targets = [torch.ops.aten.to.dtype,]
+
+    def _get_input_node(self, node):
+        input_node = node.args[0]
+        while input_node.target in self.skip_targets:
+            input_node = input_node.args[0]
+        return input_node

     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
-        partitions = get_source_partitions(
-            graph, [torch.nn.RMSNorm, torch.ops.aten.rms_norm.default]
-        )
-        for _, src_partitions in partitions.items():
-            for src_partition in src_partitions:
-                input_len = len(src_partition.input_nodes)
-                if input_len == 1:
-                    input_node = src_partition.input_nodes[0]
-                elif input_len == 2:
-                    inp_0, inp_1 = src_partition.input_nodes
-                    input_node = inp_0 if len(inp_0.users) == 2 else inp_1
-                else:
-                    raise RuntimeError(
-                        f"Found a edge case of rms_node partition {src_partition}, which has {input_len} inputs"
-                    )

-                output_node = src_partition.output_nodes[0]
-                eps = self._get_eps_node(src_partition.nodes)
-                if isinstance(eps, torch.fx.Node) and is_parameter(
-                    eps, self.edge_program
-                ):
-                    eps = get_parameter(eps, self.edge_program).item()
-                gamma_node = self._get_gamma_node(output_node)
+        # Root Mean Square normalization math equivalent implementation
+        patterns = [
+            # transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
+            [_is_mul, '*', _is_mul, _is_rsqrt, _is_add, _is_mean, _is_pow],
+            # executorch.examples.models.llama.norm.RMSNorm
+            [_is_mul, '*', _is_mul, _is_rsqrt, _is_add, _is_mean, _is_mul],
+        ]
+
+        for node in graph.nodes:
+            if not _is_mul(node):
+                continue
+
+            rms_norm_patterns = [pattern for pattern in find_patterns(node, patterns) if pattern is not None]
+
+            if len(rms_norm_patterns)>0:
+                # Use first matched pattern
+                rms_norm_pattern = rms_norm_patterns[0][0]
+                last_mul_node = rms_norm_pattern[0]
+                gamma_node = None
+                # weight should be a constant
+                for arg in last_mul_node.args:
+                    if _is_get_attr(arg) or _is_placeholder(arg):
+                        gamma_node = arg
+                if gamma_node is None:
+                    continue
+
+                eps = rms_norm_pattern[4].args[1]
+                if isinstance(eps, torch.fx.Node):
+                    eps = eps.meta['val'].constant.item()
+                input_node = self._get_input_node(rms_norm_pattern[6])

-                with graph.inserting_before(output_node):
+                with graph.inserting_before(last_mul_node):
                     # args schema
                     # (Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor
                     rms_node = graph.create_node(
                         "call_function",
-                        exir_ops.edge.aten.rms_norm.default,
+                        self.rms_norm_target,
                         (
                             input_node,
                             list(gamma_node.meta["val"].shape),
                             gamma_node,
                             eps,
                         ),
                     )
-                    users = output_node.users.copy()
+                    users = last_mul_node.users.copy()
                     for user in users:
-                        user.replace_input_with(output_node, rms_node)
+                        user.replace_input_with(last_mul_node, rms_node)
                     # copy metadata
-                    rms_node.meta = output_node.meta
+                    rms_node.meta = last_mul_node.meta

         graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)
+
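
For intuition, the two patterns above start at the final elementwise multiply of a decomposed RMS norm and walk backwards through its argument chain. A rough sketch of the decomposed form the Qwen2-style pattern targets (illustrative code, not part of the commit):

    import torch

    def decomposed_rms_norm(x, gamma, eps=1e-6):
        # pow -> mean: mean of squares
        variance = x.pow(2).mean(-1, keepdim=True)
        # add -> rsqrt -> mul: normalize the input
        x = x * torch.rsqrt(variance + eps)
        # final mul by the gamma weight: the node the pattern starts from
        return gamma * x

Matching follows node.args from that last mul, so rms_norm_pattern[0] is the output mul, rms_norm_pattern[4] is the add that carries eps, and rms_norm_pattern[6] is the pow (or mul) whose input feeds the recomposed rms_norm node; the '*' wildcard absorbs optional dtype casts between the two muls.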

backends/qualcomm/_passes/utils.py

Lines changed: 66 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict
+from typing import Callable, Dict, List

 import torch
 from executorch.backends.qualcomm.builders.utils import get_parameter
@@ -121,3 +121,68 @@ def is_float_tensor(node: torch.fx.Node) -> bool:
     if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
         return False
     return node.meta["val"].dtype == torch.float32
+
+def _find_pattern(node: torch.fx.Node, pattern: List[Callable[[torch.fx.Node], bool] | str], from_args: bool=True, max_wildcard_life: int=3, verbose: bool=False):
+    '''Implement wildcard pattern matching
+    - node: fx.Node
+    - pattern: predicate list, can contain followings
+        Callable(fx.node): predicate
+        '*': wildcard
+    - from_args: if True find from node.args, otherwise from node.users
+    - max_wildcard_life: max number of skips for wildcard
+
+    If not matched, return None.
+    Otherwise, return list of matched node list, which is the same length as pattern
+    '''
+    def _is_node(node): return isinstance(node, torch.fx.Node)
+    def _pred(node, pat): return isinstance(pat, Callable) and pat(node)
+    def _next(node):
+        if from_args:
+            yield from [i for i in node.args if _is_node(i)]
+        else:
+            yield from [i for i in node.users]
+
+    asterisk = '*'
+
+    def _probe(cur, hist, pat_idx, asterisk_life_count=max_wildcard_life, verbose=verbose):
+        if pat_idx == len(pattern):
+            assert len(hist) == len(pattern)
+            if list(hist) not in matched:
+                matched.append(list(hist))
+            return
+        if verbose:
+            print(f"cur:{cur}, idx:{pat_idx}, life={asterisk_life_count}, pattern:{pattern[pat_idx]} hist={hist}")
+        if _pred(cur, pattern[pat_idx]):
+            hist.append(cur)
+            for child in _next(cur):
+                _probe(child, hist, pat_idx+1)
+            hist.pop(-1)
+        elif pattern[pat_idx] == asterisk and asterisk_life_count>0:
+            # 3 cases: ignore/consume/keep asterisk
+            # 1, Ignore asterisk
+            hist.append(None)
+            _probe(cur, hist, pat_idx+1)
+            hist.pop(-1)
+
+            # 2. Consume asterisk
+            hist.append(None)
+            for child in _next(cur):
+                _probe(child, hist, pat_idx+1)
+            hist.pop(-1)
+
+            # 3. keep asterisk and skip to next node
+            for child in _next(cur):
+                _probe(child, hist, pat_idx, asterisk_life_count-1)

+    matched = []
+    _probe(node, [], 0)
+    return matched if matched else None
+
+
+def find_patterns(node, patterns, **kwargs):
+    assert isinstance(patterns, list) and isinstance(patterns[0], list)
+    results = []
+    for pattern in patterns:
+        result = _find_pattern(node, pattern, **kwargs)
+        results.append(result)
+    return results
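
A usage sketch of the new helpers (mirroring how RecomposeRmsNorm drives them; graph_module and the _is_* predicates are assumed to be defined as in that pass):

    patterns = [
        [_is_mul, '*', _is_mul, _is_rsqrt, _is_add, _is_mean, _is_pow],
    ]
    for node in graph_module.graph.nodes:
        # One result per pattern; each entry is None or a list of matches.
        results = find_patterns(node, patterns, from_args=True)
        hits = [r for r in results if r is not None]
        if hits:
            # First match of the first matching pattern: a list aligned with
            # the pattern, with None placeholders at wildcard positions.
            matched_nodes = hits[0][0]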

backends/qualcomm/builders/node_visitor.py

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,8 @@
 QNN_TENSOR_TYPE_MAP = {
     torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
     torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
+    # Note that there is no float64 tensor data type in Qnn.
+    torch.float64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
     torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
     torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
     torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,

backends/qualcomm/quantizer/annotators.py

Lines changed: 5 additions & 1 deletion
@@ -127,7 +127,6 @@ def annotate_single_in_share_out(
         _annotated=True,
     )

-
 def annotate_single_in(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return
@@ -163,6 +162,11 @@ def annotate_single_in_single_out(
     )


+@register_annotator([torch.ops.aten.to.dtype])
+def annotate_to_dtype(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.atan.default])
 def annotate_atan(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)

backends/qualcomm/scripts/build.sh

Lines changed: 3 additions & 0 deletions
@@ -85,6 +85,7 @@ if [ "$BUILD_AARCH64" = true ]; then
         -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \
         -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
         -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
+        -DEXECUTORCH_ENABLE_LOGGING=ON \
         -DQNN_SDK_ROOT=$QNN_SDK_ROOT \
        -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \
         -DANDROID_ABI='arm64-v8a' \
@@ -105,6 +106,7 @@ if [ "$BUILD_AARCH64" = true ]; then
         -DANDROID_PLATFORM=android-30 \
         -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
         -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
+        -DEXECUTORCH_ENABLE_LOGGING=ON \
         -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
         -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
         -B$EXAMPLE_ROOT
@@ -119,6 +121,7 @@ if [ "$BUILD_AARCH64" = true ]; then
         -DANDROID_ABI='arm64-v8a' \
         -DANDROID_PLATFORM=android-30 \
         -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
+        -DEXECUTORCH_ENABLE_LOGGING=ON \
         -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
         -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
         -B$LLAMA_EXAMPLE_ROOT

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 3 deletions
@@ -977,7 +977,6 @@ def _to_edge_and_lower_llama(  # noqa: C901
         dep_table = get_passes_dependency_for_capture_program()
         passes_job[AnnotateStack][QCOM_PASS_ACTIVATE_KEY] = True
         passes_job[ConvertBmmToMatmul][QCOM_PASS_ACTIVATE_KEY] = True
-        passes_job[RecomposeRmsNorm][QCOM_PASS_ACTIVATE_KEY] = True
         passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
         passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
             "get_quant_io_dtype_fn"
@@ -1410,14 +1409,14 @@ def _get_source_transforms(  # noqa
             transforms.append(get_model_with_r1_r2(optimized_rotation_path))
             transforms.append(replace_attention_to_attention_sha)
             transforms.append(replace_causal_mask)
-            transforms.append(replace_rms_norm_with_native_rms_norm)
+            # transforms.append(replace_rms_norm_with_native_rms_norm)
             # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             transforms.append(convert_linear_to_conv2d)
         else:
             transforms.append(replace_kv_cache_with_simple_kv_cache)
             transforms.append(replace_sdpa_with_flex_sdpa)
             transforms.append(replace_causal_mask)
-            transforms.append(replace_rms_norm_with_native_rms_norm)
+            # transforms.append(replace_rms_norm_with_native_rms_norm)
             if optimized_rotation_path:
                 transforms.append(fuse_layer_norms)
                 transforms.append(get_model_with_r1_r2(optimized_rotation_path))
