Skip to content

Commit 23fe285

Browse files
pytorchbot and mcr229 authored
[ExecuTorch][Weight Sharing][XNNPACK] load named data map data for xnnpack (#9294)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #9152 by @mcr229 ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/mcr229/8/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/mcr229/8/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/mcr229/7/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/mcr229/8/orig @diff-train-skip-merge --------- Co-authored-by: Max Ren <[email protected]>
1 parent 8386b78 commit 23fe285

24 files changed

+993
-60
lines changed

backends/xnnpack/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ option(EXECUTORCH_XNNPACK_SHARED_WORKSPACE
3737
# Keeping this OFF by default due to regressions in decode and model load with
3838
# kleidi kernels
3939
option(EXECUTORCH_XNNPACK_ENABLE_KLEIDI "Enable Arm Kleidi kernels" OFF)
40+
41+
# Turning this on cache weights between partitions and methods. If weights
42+
# are shared across methods/partitions then this can reduce load time and
43+
# memory usage
44+
45+
# Keeping this off maintains existing behavior. Turning this on serializes
46+
# execution and initialization of delegates, to be revisited
47+
option(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE
48+
"Enable weights cache to cache and manage all packed weights" OFF)
49+
50+
if(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE)
51+
add_definitions(-DENABLE_XNNPACK_WEIGHTS_CACHE)
52+
endif()
4053
if(EXECUTORCH_XNNPACK_SHARED_WORKSPACE)
4154
add_definitions(-DENABLE_XNNPACK_SHARED_WORKSPACE)
4255
endif()

backends/xnnpack/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ python_library(
1919
"//executorch/exir/passes:const_prop_pass",
2020
"//executorch/exir/passes:memory_format_ops_pass",
2121
"//executorch/exir/program:program",
22+
"//executorch/backends/transforms:utils",
2223
],
2324
)

backends/xnnpack/_passes/fuse_batch_norm_with_conv.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,22 @@
77
import operator
88

99
import torch
10+
from executorch.backends.transforms.utils import (
11+
create_constant_placeholder,
12+
delete_constant_placeholder,
13+
)
1014

1115
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
1216

13-
from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
17+
from executorch.backends.xnnpack.utils.utils import (
18+
get_param_tensor,
19+
get_tensor_name,
20+
is_param_node,
21+
)
1422
from executorch.exir import ExportedProgram
1523
from executorch.exir.dialects._ops import ops as exir_ops
1624
from executorch.exir.pass_base import PassResult
25+
from torch.export.graph_signature import InputKind
1726

1827
from torch.nn.utils.fusion import fuse_conv_bn_weights
1928

@@ -28,7 +37,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass):
2837

2938
def call(self, graph_module: torch.fx.GraphModule):
3039
graph = graph_module.graph
31-
counter = 0
40+
constant_placeholders_to_delete = set()
3241
for conv in graph.nodes:
3342
# We want to discover a chain of conv -> batch_norm.
3443
# Only proceed if the current node is a conv node, and has a single
@@ -55,9 +64,11 @@ def call(self, graph_module: torch.fx.GraphModule):
5564
assert len(conv.args) == 9
5665

5766
conv_weight = get_param_tensor(self.exported_program, conv.args[1])
67+
conv_weight_name = get_tensor_name(self.exported_program, conv.args[1])
5868
assert conv_weight is not None
5969

6070
conv_bias = get_param_tensor(self.exported_program, conv.args[2])
71+
conv_bias_name = get_tensor_name(self.exported_program, conv.args[2])
6172

6273
# Get the parameters from the batchnorm op
6374
assert (
@@ -95,32 +106,57 @@ def call(self, graph_module: torch.fx.GraphModule):
95106
bn_bias,
96107
is_transpose,
97108
)
109+
fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_")
110+
if conv_bias_name == "":
111+
fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace(
112+
".", "_"
113+
)
114+
else:
115+
fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_")
98116

99117
# Modify the graph by updating the weight and bias of conv op
100118
# with the fused weight and bias params, and replacing all the users
101119
# of getitem(batchnorm) with the conv op.
102-
with graph.inserting_before(conv):
103-
fused_weight_name = f"_fused_with_bn_weight_{counter}"
104-
graph_module.register_parameter(fused_weight_name, fused_weight)
105-
fused_weight_node = graph.get_attr(fused_weight_name)
106-
fused_bias_name = f"_fused_with_bn_bias_{counter}"
107-
graph_module.register_parameter(fused_bias_name, fused_bias)
108-
fused_bias_node = graph.get_attr(fused_bias_name)
109-
110-
# Update the weight and bias of conv op
111-
conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else [])
112-
conv_args[1] = fused_weight_node
113-
conv_args[2] = fused_bias_node
114-
conv.args = tuple(conv_args)
120+
with graph.inserting_before(conv.args[1]):
121+
fused_conv_weight_node = create_constant_placeholder(
122+
exp_program=self.exported_program,
123+
graph=graph_module.graph,
124+
kind=InputKind.PARAMETER,
125+
name=fused_weight_name,
126+
data=fused_weight,
127+
)
128+
if fused_bias is not None:
129+
fused_conv_bias_node = create_constant_placeholder(
130+
exp_program=self.exported_program,
131+
graph=graph_module.graph,
132+
kind=InputKind.PARAMETER,
133+
name=fused_bias_name,
134+
data=fused_bias,
135+
)
136+
else:
137+
fused_conv_bias_node = None
138+
139+
conv.args = (
140+
conv.args[0],
141+
fused_conv_weight_node,
142+
fused_conv_bias_node,
143+
*conv.args[3:],
144+
)
145+
115146
# Remove any use of batchnorm from the graph
116147
for user in bn.users.copy():
117148
assert user.target == operator.getitem
118149
user.replace_all_uses_with(conv)
119150
graph.erase_node(user)
120151

121152
graph.erase_node(bn)
153+
constant_placeholders_to_delete.update(conv.args[1:3] + bn.args[1:5])
122154

123-
counter += 1
155+
if len(constant_placeholders_to_delete) > 0:
156+
graph_module.graph.eliminate_dead_code()
157+
for node in constant_placeholders_to_delete:
158+
if (node is not None) and (len(node.users) == 0):
159+
delete_constant_placeholder(self.exported_program, node)
124160

125161
graph_module.recompile()
126162
# To Regenerate meta data and shape information, retrace module

backends/xnnpack/operators/node_visitor.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,23 @@
3434
check_or_raise,
3535
get_input_node,
3636
get_param_tensor,
37+
get_tensor_name,
3738
is_param_node,
3839
PERM_NCHW_TO_NHWC,
3940
)
4041

41-
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
42+
from executorch.backends.xnnpack.utils.xnnpack_constants import (
43+
UINT64_MAX,
44+
XNN_INVALID_VALUE_ID,
45+
)
46+
from executorch.exir._serialize._named_data_store import NamedDataStore
4247
from torch.export import ExportedProgram
4348

4449
XNN_TYPE_MAP = {
4550
torch.float32: XNNDatatype.xnn_datatype_fp32,
4651
}
4752

4853
from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
49-
_aligned_size,
50-
_pad_to,
5154
CONSTANT_TENSOR_ALIGNMENT,
5255
)
5356

@@ -86,11 +89,11 @@ def __init__(
8689
self,
8790
exported_program: ExportedProgram,
8891
external_ids: Dict,
89-
constant_data_bytes: bytearray,
92+
named_data_store: NamedDataStore,
9093
) -> None:
9194
self._external_ids = external_ids or {}
9295
self._exported_program = exported_program or None
93-
self._constant_data_bytes = constant_data_bytes
96+
self._named_data_store = named_data_store
9497

9598
@property
9699
def external_ids(self) -> Dict:
@@ -579,11 +582,16 @@ def get_serialized_buffer_index(
579582
ctypes.POINTER(array_type),
580583
).contents
581584

582-
offset = len(self._constant_data_bytes)
585+
named_key = get_tensor_name(self.exported_program, get_attr_node)
586+
if named_key == "":
587+
raise ValueError(f"Tensor from node: {get_attr_node} has no name")
588+
583589
size = const_val.untyped_storage().nbytes()
584-
xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size))
585-
self._constant_data_bytes.extend(
586-
_pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT))
590+
xnn_graph.constant_data.append(
591+
ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
592+
)
593+
self._named_data_store.add_named_data(
594+
named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT
587595
)
588596

589597
return buffer_idx

0 commit comments

Comments
 (0)