Skip to content

Commit 28cbc2b

Browse files
chunit-quicChun-I Tsai
and
Chun-I Tsai
authored
Qualcomm AI Engine Direct - GA efficientnet (#11212)
- Fix avg pool filter size error - Add conv2d.padding patch for torch repo - Add oss script - Add test case for avg pool Co-authored-by: Chun-I Tsai <[email protected]>
1 parent 81d4e02 commit 28cbc2b

File tree

5 files changed

+246
-25
lines changed

5 files changed

+246
-25
lines changed

backends/qualcomm/builders/op_avg_pool2d.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ class AvgPool2d(NodeVisitor):
2323
def __init__(self, *args) -> None:
2424
super().__init__(*args)
2525

26+
def _get_filter_size(self, node):
27+
filter_size = cast(List[int], node.args[1])
28+
if len(filter_size) == 1:
29+
filter_size = filter_size + filter_size
30+
return filter_size
31+
2632
def define_node(
2733
self,
2834
node: torch.fx.Node,
@@ -46,31 +52,44 @@ def define_node(
4652
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
4753
nodes_to_wrappers,
4854
)
55+
56+
pt_ceil_mode = node.args[4] if len(node.args) >= 4 else False
57+
4958
# kernel info
50-
filter_size = cast(List[int], node.args[1])
51-
if len(filter_size) == 1:
52-
filter_size = filter_size + filter_size
59+
input_shape = input_node.meta["val"].shape
60+
input_h, input_w = input_shape[2], input_shape[3]
61+
filter_size = self._get_filter_size(node)
62+
if pt_ceil_mode:
63+
# filter_size might larger than input_h, input_w, use min of them
64+
filter_size = [min(filter_size[0], input_h), min(filter_size[1], input_w)]
5365
filter_size_shape = [len(filter_size)]
5466

55-
# stride info - default to kernel_size if not given
56-
stride = cast(List[int], node.args[2]) if len(node.args) > 2 else filter_size
57-
if len(stride) == 1:
58-
stride = stride + stride
59-
stride_shape = [len(stride)]
60-
6167
padding = [0, 0]
6268
if len(node.args) > 3:
6369
padding = cast(List[int], node.args[3])
6470
if len(padding) == 1:
6571
padding = padding + padding
72+
if pt_ceil_mode:
73+
ori_filter_h, ori_filter_w = self._get_filter_size(node)
74+
padding = [
75+
0 if ori_filter_h > input_h else padding[0],
76+
0 if ori_filter_w > input_w else padding[1],
77+
]
78+
6679
padding_shape = [len(padding), len(padding)]
6780

6881
# if ceil mode is True, use ceil instead of floor to compute the output shape
69-
mode = OpPoolAvg2d.RoundingMode.FLOOR
70-
if len(node.args) > 4:
71-
ceil_mode = cast(bool, node.args[4])
72-
if ceil_mode:
73-
mode = OpPoolAvg2d.RoundingMode.CEIL
82+
mode = (
83+
OpPoolAvg2d.RoundingMode.CEIL
84+
if pt_ceil_mode
85+
else OpPoolAvg2d.RoundingMode.FLOOR
86+
)
87+
88+
# stride info - default to kernel_size if not given
89+
stride = cast(List[int], node.args[2]) if len(node.args) > 2 else filter_size
90+
if len(stride) == 1:
91+
stride = stride + stride
92+
stride_shape = [len(stride)]
7493

7594
count_include_pad = True
7695
if len(node.args) > 5:

backends/qualcomm/quantizer/annotators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,7 @@ def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None:
967967
@register_annotator(
968968
[
969969
torch.ops.aten.conv2d.default,
970+
torch.ops.aten.conv2d.padding,
970971
torch.ops.aten.conv1d.default,
971972
torch.ops.aten.conv_transpose2d.input,
972973
torch.ops.aten.conv_transpose1d.default,

backends/qualcomm/tests/models.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,13 @@ def forward(self, x, y):
147147

148148

149149
class AvgPoolModule(torch.nn.Module):
150-
def __init__(self):
150+
def __init__(self, kernel_size, stride, padding, ceil_mode):
151151
super().__init__()
152152
self.avgPool = torch.nn.AvgPool2d(
153-
kernel_size=(2, 2),
154-
padding=(1, 1),
155-
stride=(1, 1),
153+
kernel_size=kernel_size,
154+
stride=stride,
155+
padding=padding,
156+
ceil_mode=ceil_mode,
156157
count_include_pad=False,
157158
)
158159

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,19 @@ def test_qnn_backend_argmin(self):
163163
self.lower_module_and_test_output(module, sample_input)
164164

165165
def test_qnn_backend_avg_pool2d(self):
166-
module = AvgPoolModule() # noqa: F405
167-
sample_input = (torch.randn(1, 3, 2, 2),)
168-
self.lower_module_and_test_output(module, sample_input)
166+
modules = [
167+
AvgPoolModule((2, 2), (1, 1), (1, 1), False), # noqa: F405
168+
AvgPoolModule((1280, 1280), (1280, 1280), (0, 0), True), # noqa: F405
169+
AvgPoolModule((1280, 1280), (1280, 1280), (320, 320), True), # noqa: F405
170+
] # noqa: F405
171+
sample_inputs = [
172+
(torch.randn(1, 3, 2, 2),),
173+
(torch.randn(1, 1280, 7, 7),),
174+
(torch.randn(1, 1280, 7, 7),),
175+
]
176+
for i, module in enumerate(modules):
177+
with self.subTest(i=i):
178+
self.lower_module_and_test_output(module, sample_inputs[i])
169179

170180
def test_qnn_backend_batch_norm(self):
171181
modules = [BatchNorm(32), BatchNorm(32, False)] # noqa: F405
@@ -1271,10 +1281,20 @@ def test_qnn_backend_argmin(self):
12711281
self.lower_module_and_test_output(module, sample_input)
12721282

12731283
def test_qnn_backend_avg_pool2d(self):
1274-
module = AvgPoolModule() # noqa: F405
1275-
sample_input = (torch.randn(1, 3, 2, 2),)
1276-
module = self.get_qdq_module(module, sample_input)
1277-
self.lower_module_and_test_output(module, sample_input)
1284+
modules = [
1285+
AvgPoolModule((2, 2), (1, 1), (1, 1), False), # noqa: F405
1286+
AvgPoolModule((1280, 1280), (1280, 1280), (0, 0), True), # noqa: F405
1287+
AvgPoolModule((1280, 1280), (1280, 1280), (320, 320), True), # noqa: F405
1288+
] # noqa: F405
1289+
sample_inputs = [
1290+
(torch.randn(1, 3, 2, 2),),
1291+
(torch.randn(1, 1280, 7, 7),),
1292+
(torch.randn(1, 1280, 7, 7),),
1293+
]
1294+
for i, module in enumerate(modules):
1295+
with self.subTest(i=i):
1296+
module = self.get_qdq_module(module, sample_inputs[i])
1297+
self.lower_module_and_test_output(module, sample_inputs[i])
12781298

12791299
def test_qnn_backend_batch_norm(self):
12801300
modules = [BatchNorm(32), BatchNorm(32, False)] # noqa: F405
@@ -3864,6 +3884,41 @@ def test_dino_v2(self):
38643884
self.assertGreaterEqual(msg["top_1"], 70)
38653885
self.assertGreaterEqual(msg["top_5"], 85)
38663886

3887+
def test_efficientnet(self):
3888+
if not self.required_envs([self.image_dataset]):
3889+
self.skipTest("missing required envs")
3890+
cmds = [
3891+
"python",
3892+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/efficientnet.py"
3893+
"--dataset",
3894+
self.image_dataset,
3895+
"--artifact",
3896+
self.artifact_dir,
3897+
"--build_folder",
3898+
self.build_folder,
3899+
"--device",
3900+
self.device,
3901+
"--model",
3902+
self.model,
3903+
"--ip",
3904+
self.ip,
3905+
"--port",
3906+
str(self.port),
3907+
]
3908+
if self.host:
3909+
cmds.extend(["--host", self.host])
3910+
3911+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
3912+
with Listener((self.ip, self.port)) as listener:
3913+
conn = listener.accept()
3914+
p.communicate()
3915+
msg = json.loads(conn.recv())
3916+
if "Error" in msg:
3917+
self.fail(msg["Error"])
3918+
else:
3919+
self.assertGreaterEqual(msg["top_1"], 70)
3920+
self.assertGreaterEqual(msg["top_5"], 85)
3921+
38673922
def test_efficientSAM(self):
38683923
if not self.required_envs(
38693924
[self.image_dataset, self.pretrained_weight, self.oss_repo]
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import json
8+
import logging
9+
import os
10+
from multiprocessing.connection import Client
11+
12+
import numpy as np
13+
14+
import torch
15+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
16+
from executorch.examples.qualcomm.utils import (
17+
build_executorch_binary,
18+
get_imagenet_dataset,
19+
make_output_dir,
20+
parse_skip_delegation_node,
21+
setup_common_args_and_variables,
22+
SimpleADB,
23+
topk_accuracy,
24+
)
25+
from transformers import AutoModelForImageClassification
26+
27+
28+
def main(args):
29+
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
30+
31+
# ensure the working directory exist.
32+
os.makedirs(args.artifact, exist_ok=True)
33+
34+
if not args.compile_only and args.device is None:
35+
raise RuntimeError(
36+
"device serial is required if not compile only. "
37+
"Please specify a device serial by -s/--device argument."
38+
)
39+
40+
data_num = 100
41+
if args.ci:
42+
inputs = [(torch.rand(1, 3, 224, 224),)]
43+
logging.warning(
44+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
45+
)
46+
else:
47+
inputs, targets, input_list = get_imagenet_dataset(
48+
dataset_path=f"{args.dataset}",
49+
data_size=data_num,
50+
image_shape=(256, 256),
51+
crop_size=224,
52+
)
53+
54+
module = (
55+
AutoModelForImageClassification.from_pretrained("google/efficientnet-b0")
56+
.eval()
57+
.to("cpu")
58+
)
59+
pte_filename = "efficientnet_qnn_q16"
60+
build_executorch_binary(
61+
module.eval(),
62+
inputs[0],
63+
args.model,
64+
f"{args.artifact}/{pte_filename}",
65+
inputs,
66+
skip_node_id_set=skip_node_id_set,
67+
skip_node_op_set=skip_node_op_set,
68+
quant_dtype=QuantDtype.use_16a16w,
69+
shared_buffer=args.shared_buffer,
70+
)
71+
72+
if args.compile_only:
73+
return
74+
75+
adb = SimpleADB(
76+
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
77+
build_path=f"{args.build_folder}",
78+
pte_path=f"{args.artifact}/{pte_filename}.pte",
79+
workspace=f"/data/local/tmp/executorch/{pte_filename}",
80+
device_id=args.device,
81+
host_id=args.host,
82+
soc_model=args.model,
83+
shared_buffer=args.shared_buffer,
84+
)
85+
adb.push(inputs=inputs, input_list=input_list)
86+
adb.execute()
87+
88+
# collect output data
89+
output_data_folder = f"{args.artifact}/outputs"
90+
make_output_dir(output_data_folder)
91+
92+
adb.pull(output_path=args.artifact)
93+
94+
# top-k analysis
95+
predictions = []
96+
for i in range(data_num):
97+
predictions.append(
98+
np.fromfile(
99+
os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
100+
)
101+
)
102+
103+
k_val = [1, 5]
104+
topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
105+
if args.ip and args.port != -1:
106+
with Client((args.ip, args.port)) as conn:
107+
conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
108+
else:
109+
for i, k in enumerate(k_val):
110+
print(f"top_{k}->{topk[i]}%")
111+
112+
113+
if __name__ == "__main__":
114+
parser = setup_common_args_and_variables()
115+
116+
parser.add_argument(
117+
"-d",
118+
"--dataset",
119+
help=(
120+
"path to the validation folder of ImageNet dataset. "
121+
"e.g. --dataset imagenet-mini/val "
122+
"for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
123+
),
124+
type=str,
125+
required=False,
126+
)
127+
128+
parser.add_argument(
129+
"-a",
130+
"--artifact",
131+
help="path for storing generated artifacts by this example. "
132+
"Default ./efficientnet",
133+
default="./efficientnet",
134+
type=str,
135+
)
136+
137+
args = parser.parse_args()
138+
try:
139+
main(args)
140+
except Exception as e:
141+
if args.ip and args.port != -1:
142+
with Client((args.ip, args.port)) as conn:
143+
conn.send(json.dumps({"Error": str(e)}))
144+
else:
145+
raise Exception(e)

0 commit comments

Comments
 (0)