Commit 801006d
[QNN-EP] Define SpaceToDepth fusion for YOLOv2. (#24848)

### Description

- Add a SpaceToDepth fusion to the QNN preprocess.
- The pattern found in YOLOv2 is uncommon; the commonly seen variant is left as future work.
- Add an entry point/API for non-quantization users to preprocess models for QNN execution.
- Revise cmake to package the newly introduced directory into the Python wheel.

### Motivation and Context

- When executing the YOLOv2 model on QNN-EP, a sequence of Reshape and Transpose nodes with 6D shapes falls back to CPU due to an HTP limitation. This change adds a fusion that folds the sequence into a single SpaceToDepth node, which QNN-EP can execute directly.
- The current QNN preprocess lives in `onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py`, under the quantization directory, so the path may be confusing for non-quantization users. To let non-quantization users preprocess models for QNN, this change introduces `onnxruntime/python/tools/qnn/preprocess.py` as the entry point and provides an API to preprocess models.
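As a quick illustration (not part of the diff), the new entry point can be run as a script with the flags it defines; `model.onnx` and `model_qnn.onnx` are placeholder paths:

    python onnxruntime/python/tools/qnn/preprocess.py --input_model_path model.onnx --output_model_path model_qnn.onnx --fuse_layernorm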
1 parent 9349c37 commit 801006d

File tree

5 files changed: +317 -0 lines changed

cmake/onnxruntime_python.cmake

Lines changed: 7 additions & 0 deletions
@@ -453,6 +453,9 @@ endif()
 file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/python/tools/*.py"
 )
+file(GLOB onnxruntime_python_tools_qnn_src CONFIGURE_DEPENDS
+  "${ONNXRUNTIME_ROOT}/python/tools/qnn/*.py"
+)
 file(GLOB onnxruntime_python_quantization_src CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/python/tools/quantization/*.py"
 )
@@ -564,6 +567,7 @@ add_custom_command(
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/tools/qdq_helpers
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/tools/ort_format_model
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/tools/ort_format_model/ort_flatbuffers_py
+  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/tools/qnn
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bart
@@ -649,6 +653,9 @@ add_custom_command(
   COMMAND ${CMAKE_COMMAND} -E copy_directory
     ${ONNXRUNTIME_ROOT}/core/flatbuffers/ort_flatbuffers_py
     $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/tools/ort_format_model/ort_flatbuffers_py
+  COMMAND ${CMAKE_COMMAND} -E copy
+    ${onnxruntime_python_tools_qnn_src}
+    $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/tools/qnn/
   COMMAND ${CMAKE_COMMAND} -E copy
     ${onnxruntime_python_quantization_src}
     $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/
onnxruntime/python/tools/qnn/preprocess.py

Lines changed: 139 additions & 0 deletions

@@ -0,0 +1,139 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""Provide entry point to preprocess ONNX model especially for QNN."""

import argparse
import pathlib

import onnx

from onnxruntime.quantization.execution_providers import qnn


def _parse_arguments():
    """Parse cmdline arguments."""
    parser = argparse.ArgumentParser(description="Arguments for QNN model preprocess.")

    parser.add_argument("--input_model_path", "-i", required=True, help="Path to the input ONNX model.")
    parser.add_argument("--output_model_path", "-o", required=True, help="Path to the output ONNX model.")

    # Save preprocessed model with external data.
    parser.add_argument(
        "--save_as_external_data",
        action="store_true",
        help="Whether the output model would be saved with external data.",
    )
    parser.add_argument(
        "--all_tensors_to_one_file",
        action="store_true",
        help="Whether to save all external data in one file or save each tensor to a file named with the tensor name.",
    )
    parser.add_argument(
        "--external_data_location",
        help="Filename of the external file where all tensors are saved. The path is relative to the model path.",
    )
    parser.add_argument(
        "--external_data_size_threshold",
        default=1024,
        type=int,
        help="Tensors with data size larger than this threshold are converted to external data.",
    )
    parser.add_argument(
        "--external_data_convert_attribute",
        action="store_true",
        help="Whether to save all tensors, including attribute tensors, to external data.",
    )

    # Preprocess options.
    parser.add_argument(
        "--fuse_layernorm",
        action="store_true",
        help="Whether to fuse matched sequences into LayerNormalization nodes if possible.",
    )

    # I/O layouts.
    parser.add_argument(
        "--inputs_to_make_channel_last",
        nargs="+",
        default=None,
        help="List of graph input names to be transposed into channel-last.",
    )

    parser.add_argument(
        "--outputs_to_make_channel_last",
        nargs="+",
        default=None,
        help="List of graph output names to be transposed into channel-last.",
    )

    return parser.parse_args()


def qnn_preprocess_model(
    model_input: str | pathlib.Path | onnx.ModelProto,
    model_output: str | pathlib.Path,
    fuse_layernorm: bool = False,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = False,
    external_data_location: str | None = None,
    external_data_size_threshold: int = 1024,
    external_data_convert_attribute: bool = False,
    inputs_to_make_channel_last: list[str] | None = None,
    outputs_to_make_channel_last: list[str] | None = None,
) -> bool:
    """Preprocess ONNX model for QNN.

    Args:
        model_input: A path or ONNX ModelProto specifying the model to be preprocessed.
        model_output: A path specifying where the preprocessed model is to be saved.
        fuse_layernorm: A bool specifying whether to fuse the matched sequence into a single LayerNormalization node.
            Defaults to False.
        save_as_external_data: A bool specifying whether to save the model with external data. Defaults to False.
        all_tensors_to_one_file: A bool specifying whether to save all external data in one file or save each tensor
            to a file named with the tensor name. This argument is effective only when `save_as_external_data` is
            True. Defaults to False.
        external_data_location: A str specifying where to save the external data. The path is relative to the model
            path. This argument is effective only when `save_as_external_data` is True. Defaults to the model name.
        external_data_size_threshold: An int specifying the size threshold above which tensors are saved as external
            data. This argument is effective only when `save_as_external_data` is True. Defaults to 1024.
        external_data_convert_attribute: A bool specifying whether to save all tensors, including attribute tensors,
            as external data. This argument is effective only when `save_as_external_data` is True. Defaults to
            False.
        inputs_to_make_channel_last: A list of strs specifying graph input names to be transposed into channel-last.
            Defaults to None.
        outputs_to_make_channel_last: A list of strs specifying graph output names to be transposed into channel-last.
            Defaults to None.

    Returns:
        A bool indicating whether the model is modified.
    """
    return qnn.qnn_preprocess_model(
        model_input,
        model_output,
        fuse_layernorm=fuse_layernorm,
        save_as_external_data=save_as_external_data,
        all_tensors_to_one_file=all_tensors_to_one_file,
        external_data_location=external_data_location,
        external_data_size_threshold=external_data_size_threshold,
        external_data_convert_attribute=external_data_convert_attribute,
        inputs_to_make_channel_last=inputs_to_make_channel_last,
        outputs_to_make_channel_last=outputs_to_make_channel_last,
    )


if __name__ == "__main__":
    args = _parse_arguments()
    qnn_preprocess_model(
        args.input_model_path,
        args.output_model_path,
        fuse_layernorm=args.fuse_layernorm,
        save_as_external_data=args.save_as_external_data,
        all_tensors_to_one_file=args.all_tensors_to_one_file,
        external_data_location=args.external_data_location,
        external_data_size_threshold=args.external_data_size_threshold,
        external_data_convert_attribute=args.external_data_convert_attribute,
        inputs_to_make_channel_last=args.inputs_to_make_channel_last,
        outputs_to_make_channel_last=args.outputs_to_make_channel_last,
    )
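A minimal usage sketch of the API defined above (not part of the diff): the model paths are placeholders, and the `onnxruntime.tools.qnn` import path assumes the wheel layout set up by this commit's cmake and setup.py changes.

```python
from onnxruntime.tools.qnn.preprocess import qnn_preprocess_model  # import path assumes the packaged wheel layout

# Apply QNN-friendly fusions (e.g., SpaceToDepth) plus the optional LayerNormalization fusion.
modified = qnn_preprocess_model(
    "yolov2.onnx",               # placeholder input model path
    "yolov2_preprocessed.onnx",  # placeholder output model path
    fuse_layernorm=True,
)
print("Model modified:", modified)
```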
onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_spacetodepth.py

Lines changed: 162 additions & 0 deletions

@@ -0,0 +1,162 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""Define SpaceToDepth fusion."""

import onnx

from ... import fusions, onnx_model


class FusionSpaceToDepth(fusions.Fusion):
    """Fusion for SpaceToDepth."""

    def __init__(self, model: onnx_model.ONNXModel):
        """Initialize.

        Args:
            model: An onnx_model.ONNXModel instance.
        """
        super().__init__(model, "SpaceToDepth", "Reshape")

    def _fuse_yolo(
        self,
        node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """Fuse for early versions of YOLO.

        Pattern:

              | [N, C, H, W]
            Reshape
              | [N, C, H/blk, blk, W/blk, blk]
            Transpose
              | [N, C, H/blk, W/blk, blk, blk]
            Reshape
              | [N, C, H/blk * W/blk, blk * blk]
            Transpose
              | [N, C, blk * blk, H/blk * W/blk]
            Reshape
              | [N, C, blk * blk, H/blk, W/blk]
            Transpose
              | [N, blk * blk, C, H/blk, W/blk]
            Reshape
              | [N, blk * blk * C, H/blk, W/blk]

        This sequence can be fused into a single SpaceToDepth with blocksize `blk`. Note that unlike DepthToSpace,
        which supports DCR or CRD mode, SpaceToDepth only supports DCR mode in its latest opset version (13), which
        matches the pattern here.
        """
        reshape_node1 = node

        def get_target_child(parent_node, target_op_type):
            """Get target child of given node."""
            if parent_node.output[0] not in input_name_to_nodes:
                return None

            children = input_name_to_nodes[parent_node.output[0]]
            if len(children) > 1 or children[0].op_type != target_op_type:
                return None

            return children[0]

        if (
            (transpose_node1 := get_target_child(reshape_node1, "Transpose")) is None
            or (reshape_node2 := get_target_child(transpose_node1, "Reshape")) is None
            or (transpose_node2 := get_target_child(reshape_node2, "Transpose")) is None
            or (reshape_node3 := get_target_child(transpose_node2, "Reshape")) is None
            or (transpose_node3 := get_target_child(reshape_node3, "Transpose")) is None
            or (reshape_node4 := get_target_child(transpose_node3, "Reshape")) is None
        ):
            return False

        def get_tensor_shape(tensor_name):
            """Get shape for given tensor name."""
            tensor_type = self.model.get_tensor_type(tensor_name)
            if not tensor_type:
                return None

            tensor_shape = self.tensor_shape_to_list(tensor_type)
            if not tensor_shape:
                return None

            return tensor_shape

        if (
            (input_shape := get_tensor_shape(reshape_node1.input[0])) is None
            or (reshape_shape1 := get_tensor_shape(reshape_node1.output[0])) is None
            or (reshape_shape2 := get_tensor_shape(reshape_node2.output[0])) is None
            or (reshape_shape3 := get_tensor_shape(reshape_node3.output[0])) is None
            or (reshape_shape4 := get_tensor_shape(reshape_node4.output[0])) is None
        ):
            return False

        transpose_perm1 = self.get_node_attribute(transpose_node1, "perm")
        transpose_perm2 = self.get_node_attribute(transpose_node2, "perm")
        transpose_perm3 = self.get_node_attribute(transpose_node3, "perm")

        # Check rank.
        if (
            len(input_shape) != 4
            or len(reshape_shape1) != 6
            or len(reshape_shape2) != 4
            or len(reshape_shape3) != 5
            or len(reshape_shape4) != 4
        ):
            return False

        # Check shape and perm.
        batch, channel, height, width = input_shape
        blocksize = reshape_shape1[3]
        if (
            reshape_shape1 != [batch, channel, height // blocksize, blocksize, width // blocksize, blocksize]
            or transpose_perm1 != [0, 1, 2, 4, 3, 5]
            or reshape_shape2 != [batch, channel, (height // blocksize) * (width // blocksize), blocksize**2]
            or transpose_perm2 != [0, 1, 3, 2]
            or reshape_shape3 != [batch, channel, blocksize**2, height // blocksize, width // blocksize]
            or transpose_perm3 != [0, 2, 1, 3, 4]
            or reshape_shape4 != [batch, blocksize**2 * channel, height // blocksize, width // blocksize]
        ):
            return False

        self.nodes_to_remove.extend(
            [
                reshape_node1,
                transpose_node1,
                reshape_node2,
                transpose_node2,
                reshape_node3,
                transpose_node3,
                reshape_node4,
            ]
        )

        s2d_node = onnx.helper.make_node(
            self.fused_op_type,
            name=self.create_unique_node_name(),
            inputs=[reshape_node1.input[0]],
            outputs=[reshape_node4.output[0]],
            blocksize=blocksize,
        )
        self.nodes_to_add.append(s2d_node)

        return True

    def fuse(
        self,
        node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """Fuse a sequence of Reshape and Transpose nodes into a single SpaceToDepth node.

        Args:
            node: An onnx.NodeProto matching the specified search type (i.e., Reshape).
            input_name_to_nodes: A dict mapping tensor name to consuming nodes.
            output_name_to_node: A dict mapping tensor name to producing node.
        """
        self._fuse_yolo(node, input_name_to_nodes, output_name_to_node)
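As a sanity check (not part of the commit), the following numpy sketch replays the Reshape/Transpose chain from the `_fuse_yolo` docstring and compares it against the SpaceToDepth (DCR) reference computation from the ONNX operator spec; the shapes and `blk` are arbitrary placeholders:

```python
import numpy as np

N, C, H, W, blk = 1, 3, 4, 6, 2
x = np.arange(N * C * H * W, dtype=np.float32).reshape(N, C, H, W)

# The YOLOv2 Reshape/Transpose chain matched by _fuse_yolo.
y = x.reshape(N, C, H // blk, blk, W // blk, blk)
y = y.transpose(0, 1, 2, 4, 3, 5)
y = y.reshape(N, C, (H // blk) * (W // blk), blk * blk)
y = y.transpose(0, 1, 3, 2)
y = y.reshape(N, C, blk * blk, H // blk, W // blk)
y = y.transpose(0, 2, 1, 3, 4)
y = y.reshape(N, blk * blk * C, H // blk, W // blk)

# SpaceToDepth (DCR) reference computation from the ONNX operator spec.
z = x.reshape(N, C, H // blk, blk, W // blk, blk)
z = z.transpose(0, 3, 5, 1, 2, 4)
z = z.reshape(N, C * blk * blk, H // blk, W // blk)

assert np.array_equal(y, z)  # the chain equals SpaceToDepth with blocksize=blk
```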

onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py

Lines changed: 8 additions & 0 deletions
@@ -12,7 +12,9 @@
 
 from ...fusions import FusionGelu, FusionLayerNormalization
 from ...onnx_model import ONNXModel
+from ...quant_utils import save_and_reload_model_with_shape_infer
 from .fusion_lpnorm import FusionLpNormalization
+from .fusion_spacetodepth import FusionSpaceToDepth
 
 
 def qnn_preprocess_model(
@@ -83,6 +85,7 @@ def qnn_preprocess_model(
     """
     modified = False
     model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input)
+    model = save_and_reload_model_with_shape_infer(model)
     onnx_model = ONNXModel(model)
 
     # Fuse Erf sequence into a single Gelu
@@ -95,6 +98,11 @@ def qnn_preprocess_model(
     if fusion_lpnorm.apply():
         modified = True
 
+    # Fuse Reshape/Transpose sequence into a single SpaceToDepth.
+    fusion_s2d = FusionSpaceToDepth(onnx_model)
+    if fusion_s2d.apply():
+        modified = True
+
     # Optionally, fuse ReduceMean sequence into a single LayerNormalization node.
     if fuse_layernorm:
         onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx")

setup.py

Lines changed: 1 addition & 0 deletions
@@ -514,6 +514,7 @@ def finalize_options(self):
     "onnxruntime.tools.ort_format_model.ort_flatbuffers_py",
     "onnxruntime.tools.ort_format_model.ort_flatbuffers_py.fbs",
     "onnxruntime.tools.qdq_helpers",
+    "onnxruntime.tools.qnn",
     "onnxruntime.quantization",
     "onnxruntime.quantization.operators",
     "onnxruntime.quantization.CalTableFlatBuffers",
