Skip to content

Commit ba3c16b

Browse files
committed
Qualcomm AI Engine Direct - GA FocalNet
1 parent 95a1db5 commit ba3c16b

File tree

9 files changed

+197
-10
lines changed

9 files changed

+197
-10
lines changed

backends/qualcomm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
252252

253253
pybind11_extension(PyQnnManagerAdaptor)
254254
pybind11_extension(PyQnnWrapperAdaptor)
255-
if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES Debug|RelWithDebInfo)
255+
if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES RelWithDebInfo)
256256
# Strip unnecessary sections of the binary
257257
pybind11_strip(PyQnnManagerAdaptor)
258258
pybind11_strip(PyQnnWrapperAdaptor)

backends/qualcomm/_passes/annotate_adaptive_avg_pool1d.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ class AnnotateAdaptiveAvgPool1D(ExportPass):
1919
adaptive_avg_pool1d got decomposed to unsqueeze -> adaptive_avg_pool2d -> squeeze
2020
"""
2121

22-
decomp_ops = [torch.ops.aten.adaptive_avg_pool2d.default]
23-
2422
def __init__(self, edge_program: torch.export.ExportedProgram):
2523
super(AnnotateAdaptiveAvgPool1D, self).__init__()
2624
self.edge_program = edge_program

backends/qualcomm/_passes/annotate_stack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _annotate_stack(self, graph_module: torch.fx.GraphModule):
2828
partitions = get_source_partitions(
2929
graph_module.graph, [torch.stack, torch.ops.aten.stack.default, "stack"]
3030
)
31-
for _, src_partitions in partitions.items():
31+
for src_partitions in partitions.values():
3232
for src_partition in src_partitions:
3333
output = src_partition.output_nodes[0]
3434
if (list(output.users)[0].target) in q_ops:

backends/qualcomm/_passes/annotate_unbind.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
2828
partitions = get_source_partitions(
2929
graph_module.graph, [torch.unbind, torch.ops.aten.unbind.int, "unbind"]
3030
)
31-
for _, src_partitions in partitions.items():
31+
for src_partitions in partitions.values():
3232
for src_partition in src_partitions:
3333
if src_partition.input_nodes[0].target in dq_ops:
3434
q_node = src_partition.input_nodes[0].args[0]

backends/qualcomm/quantizer/annotators.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,13 @@ def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None
11931193
)
11941194

11951195

1196-
@register_annotator([torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default])
1196+
@register_annotator(
1197+
[
1198+
torch.ops.aten.split_with_sizes.default,
1199+
torch.ops.aten.split.Tensor,
1200+
torch.ops.aten.chunk.default,
1201+
]
1202+
)
11971203
def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
11981204
if _is_annotated([node]):
11991205
return

backends/qualcomm/scripts/build.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ CMAKE_X86_64="build-x86"
3030
BUILD_AARCH64="true"
3131
CMAKE_AARCH64="build-android"
3232
CLEAN="true"
33-
BUILD_TYPE="Debug"
33+
BUILD_TYPE="RelWithDebInfo"
3434
BUILD_JOB_NUMBER="16"
3535

3636
if [ -z PYTHON_EXECUTABLE ]; then
@@ -71,7 +71,7 @@ if [ "$BUILD_AARCH64" = true ]; then
7171
rm -rf $BUILD_ROOT && mkdir $BUILD_ROOT
7272
else
7373
# Force rebuild flatccrt for the correct platform
74-
cd $BUILD_ROOT/devtools && make clean
74+
cd $BUILD_ROOT/third-party/flatcc && make clean
7575
fi
7676

7777
cd $BUILD_ROOT
@@ -116,7 +116,7 @@ if [ "$BUILD_X86_64" = true ]; then
116116
rm -rf $BUILD_ROOT && mkdir $BUILD_ROOT
117117
else
118118
# Force rebuild flatccrt for the correct platform
119-
cd $BUILD_ROOT/devtools && make clean
119+
cd $BUILD_ROOT/third-party/flatcc && make clean
120120
fi
121121

122122
cd $BUILD_ROOT

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4108,6 +4108,44 @@ def test_fbnet(self):
41084108
self.assertGreaterEqual(msg["top_1"], 60)
41094109
self.assertGreaterEqual(msg["top_5"], 90)
41104110

4111+
def test_focalnet(self):
    """Run the FocalNet example script and check the accuracy it reports.

    The script is launched as a subprocess and sends a JSON message back
    over a local connection; the message holds either an "Error" entry or
    "top_1"/"top_5" scores.
    """
    if not self.required_envs([self.image_dataset]):
        self.skipTest("missing required envs")

    script_path = f"{self.executorch_root}/examples/qualcomm/oss_scripts/focalnet.py"
    cmds = ["python", script_path]
    cmds += ["--dataset", self.image_dataset]
    cmds += ["--artifact", self.artifact_dir]
    cmds += ["--build_folder", self.build_folder]
    cmds += ["--device", self.device]
    cmds += ["--model", self.model]
    cmds += ["--ip", self.ip]
    cmds += ["--port", str(self.port)]
    # Optional flags are only forwarded when configured on the test case.
    if self.host:
        cmds += ["--host", self.host]
    if self.shared_buffer:
        cmds.append("--shared_buffer")

    proc = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
    with Listener((self.ip, self.port)) as listener:
        conn = listener.accept()
        proc.communicate()
        msg = json.loads(conn.recv())
        # self.fail() raises, so the score checks below only run on success.
        if "Error" in msg:
            self.fail(msg["Error"])
        self.assertGreaterEqual(msg["top_1"], 55)
        self.assertGreaterEqual(msg["top_5"], 80)
4148+
41114149
def test_gMLP(self):
41124150
if not self.required_envs([self.image_dataset]):
41134151
self.skipTest("missing required envs")

examples/qualcomm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ if(NOT PYTHON_EXECUTABLE)
2323
endif()
2424

2525
if(NOT CMAKE_BUILD_TYPE)
26-
set(CMAKE_BUILD_TYPE Debug)
26+
set(CMAKE_BUILD_TYPE RelWithDebInfo)
2727
endif()
2828

2929
# Find prebuilt libraries. executorch package should contain portable_ops_lib,
examples/qualcomm/oss_scripts/focalnet.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import json
8+
import logging
9+
import os
10+
from multiprocessing.connection import Client
11+
12+
import numpy as np
13+
14+
import torch
15+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
16+
from executorch.examples.qualcomm.utils import (
17+
build_executorch_binary,
18+
get_imagenet_dataset,
19+
make_output_dir,
20+
parse_skip_delegation_node,
21+
setup_common_args_and_variables,
22+
SimpleADB,
23+
topk_accuracy,
24+
)
25+
from transformers import AutoModelForImageClassification
26+
27+
28+
def main(args):
    """Export FocalNet-tiny to a QNN 8a8w ``.pte`` and optionally score it on device.

    Steps, in order:
      1. Build calibration inputs (random tensor in ci mode, ImageNet otherwise).
      2. Load ``microsoft/focalnet-tiny`` and compile it with
         ``build_executorch_binary`` into ``<artifact>/focalnet_qnn_q8``.
      3. Unless ``args.compile_only`` is set, push the inputs to the device,
         run the program, pull the outputs and compute top-1/top-5 accuracy.
      4. Send the scores as JSON to ``(args.ip, args.port)`` when both are
         set, otherwise print them.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``artifact``, ``compile_only``, ``device``, ``ci``, ``dataset``,
            ``model``, ``shared_buffer``, ``build_folder``, ``host``, ``ip``
            and ``port``.

    Raises:
        RuntimeError: if on-device execution is requested without a device
            serial, or if on-device execution is requested in ci mode
            (ci mode has no labels or input list to evaluate against).
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    data_num = 100
    if args.ci:
        inputs = [(torch.rand(1, 3, 224, 224),)]
        logging.warning(
            "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
        )
    else:
        inputs, targets, input_list = get_imagenet_dataset(
            dataset_path=f"{args.dataset}",
            data_size=data_num,
            image_shape=(256, 256),
            crop_size=224,
        )

    module = (
        AutoModelForImageClassification.from_pretrained("microsoft/focalnet-tiny")
        .eval()
        .to("cpu")
    )
    pte_filename = "focalnet_qnn_q8"
    build_executorch_binary(
        module.eval(),
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    if args.ci:
        # The ci branch above only creates `inputs`; `input_list` and
        # `targets` do not exist, so the steps below would die with a
        # NameError. Fail with an actionable message instead.
        raise RuntimeError(
            "ci mode uses random input without labels, so on-device "
            "evaluation is not available. Run ci mode with compile only, "
            "or drop it and pass -d/--dataset."
        )

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    # top-k analysis: one raw float32 output file is read per input sample
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(data_num)
    ]

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": acc for k, acc in zip(k_val, topk)}))
    else:
        for k, acc in zip(k_val, topk):
            print(f"top_{k}->{acc}%")
111+
112+
113+
if __name__ == "__main__":
    # Script entry point: extend the shared example parser with the options
    # specific to this script, then run main().
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=False,
    )

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./focalnet",
        default="./focalnet",
        type=str,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # A listener is waiting on (ip, port): report the failure to it
            # instead of raising.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare raise keeps the original exception type; wrapping it in
            # Exception(e) would reduce it to a generic Exception.
            raise

0 commit comments

Comments
 (0)