
Commit f82c2f0

[XNNPACK] Add support for Linear fused BatchNorm (pytorch#11805)
Co-authored-by: Digant Desai <[email protected]>
1 parent 73c124c commit f82c2f0

File tree: 5 files changed, +312 -218 lines


backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -21,9 +21,7 @@
 )
 from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
-from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
-    FuseBatchNormWithConvPass,
-)
+from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

@@ -60,7 +58,7 @@ def __init__(
             ConvertToLinearPass,
             ConvertToSDPAPass,
             ConstPropPass,
-            FuseBatchNormWithConvPass,
+            FuseBatchNormPass,
             FuseActivationPass,
             DecomposeConcatenate,
             RemoveGetItemPass,
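
With this change, FuseBatchNormPass takes the slot previously held by FuseBatchNormWithConvPass in the default XNNPACK pass list, so the fusion runs automatically when a model is lowered to XNNPACK. A minimal end-to-end sketch of that flow (the module and shapes are illustrative, not part of this commit; it assumes the standard to_edge_transform_and_lower entry point):

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class LinearBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
        self.bn = torch.nn.BatchNorm1d(4)

    def forward(self, x):
        return self.bn(self.linear(x))


# eval() matters: the pass matches the no-training batch_norm variant.
model = LinearBN().eval()
exported = torch.export.export(model, (torch.randn(2, 8),))

# The XNNPACK pass pipeline (including FuseBatchNormPass) runs during lowering;
# afterwards the delegated graph carries one fused linear and no batch_norm op.
lowered = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])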
backends/xnnpack/_passes/fuse_batch_norm.py

Lines changed: 247 additions & 0 deletions

@@ -0,0 +1,247 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

import torch
from executorch.backends.transforms.utils import (
    create_constant_placeholder,
    delete_constant_placeholder,
)

from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

from executorch.backends.xnnpack.utils.utils import (
    get_param_tensor,
    get_tensor_name,
    is_param_node,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.export.graph_signature import InputKind

from torch.nn.utils.fusion import fuse_conv_bn_weights, fuse_linear_bn_weights


class FuseBatchNormPass(XNNPACKPass):
    """
    BatchNorm can be implemented using 1x1 Depthwise Convolution. However, doing so will increase
    memory usage since we serialize new weights to represent the convolution. In most cases,
    BatchNorm is used after convolution or linear. The 1x1 depthwise convolution can then be fused
    with the previous convolution. For linear cases, BatchNorm can be folded into the previous linear layer.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        constant_placeholders_to_delete = set()
        for input_node in graph.nodes:
            # We want to discover a chain of conv -> batch_norm or linear -> batch_norm.
            # Only proceed if the current node is a conv or linear, and has a single user/successor.
            is_conv = input_node.target == exir_ops.edge.aten.convolution.default
            is_linear = input_node.target == exir_ops.edge.aten.linear.default

            if not (is_conv or is_linear) or len(input_node.users) != 1:
                continue

            # The single user of the conv or linear node must be batch_norm. If not, bail.
            bn = list(input_node.users.keys())[0]
            if (
                bn.target != exir_ops.edge.aten.native_batch_norm.default
                and bn.target
                != exir_ops.edge.aten._native_batch_norm_legit_no_training.default
            ):
                continue

            if not self.can_fuse(input_node, bn, self.exported_program):
                continue

            self._fuse_ops(
                graph_module,
                graph,
                input_node,
                bn,
                is_conv,
                constant_placeholders_to_delete,
            )

        if len(constant_placeholders_to_delete) > 0:
            graph_module.graph.eliminate_dead_code()
            for node in constant_placeholders_to_delete:
                if (node is not None) and (len(node.users) == 0):
                    delete_constant_placeholder(self.exported_program, node)

        graph_module.recompile()
        # To regenerate metadata and shape information, retrace module.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)

    @staticmethod
    def can_fuse(
        input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
    ) -> bool:
        """
        Determine whether a BatchNorm node can be fused with the preceding convolution or linear node.
        """

        # All users of the batch_norm node must be getitem ops.
        # batch_norm returns a 3-element tuple.
        # Each user must only access the first element of the tuple.
        if [
            (user.target == operator.getitem and user.args[1] == 0) for user in bn.users
        ].count(False):
            return False

        input_node_weights = input_node.args[1]
        bn_weights = bn.args[1]

        # Check that the weights for conv or linear and batch_norm are both params.
        if not isinstance(input_node_weights, torch.fx.Node) or not isinstance(
            bn_weights, torch.fx.Node
        ):
            return False

        if [
            is_param_node(program, node) for node in {input_node_weights, bn_weights}
        ].count(False):
            return False

        return True

    def _fuse_ops(
        self,
        graph_module: torch.fx.GraphModule,
        graph: torch.fx.Graph,
        input_node: torch.fx.Node,
        bn: torch.fx.Node,
        is_conv: bool,
        constant_placeholders_to_delete: set,
    ) -> None:
        """
        Fuse a BatchNorm node into the preceding convolution or linear node.
        Update the fused node's weight and bias, rewire users of the BatchNorm output,
        and remove the BatchNorm node.
        """

        if is_conv:
            assert len(input_node.args) == 9
            has_bias_arg = True
        else:
            # Otherwise, this is a linear node.
            # Linear has 2 or 3 args depending on whether bias is used: (input, weight, bias).
            assert len(input_node.args) in (2, 3)
            has_bias_arg = len(input_node.args) == 3

        # Get the weight and bias parameters from the conv or linear op.
        input_node_weight = get_param_tensor(self.exported_program, input_node.args[1])
        input_node_weight_name = get_tensor_name(
            self.exported_program, input_node.args[1]
        )
        assert input_node_weight is not None

        if has_bias_arg:
            input_node_bias = get_param_tensor(
                self.exported_program, input_node.args[2]
            )
            input_node_bias_name = get_tensor_name(
                self.exported_program, input_node.args[2]
            )
        else:
            input_node_bias = None
            input_node_bias_name = ""

        # Get the parameters from the batch_norm op.
        assert (
            bn.target == exir_ops.edge.aten.native_batch_norm.default
            and len(bn.args) == 8
        ) or (
            bn.target == exir_ops.edge.aten._native_batch_norm_legit_no_training.default
            and len(bn.args) == 7
        )
        bn_weight = get_param_tensor(self.exported_program, bn.args[1])
        bn_bias = get_param_tensor(self.exported_program, bn.args[2])

        running_mean = get_param_tensor(self.exported_program, bn.args[3])
        assert running_mean is not None

        running_var = get_param_tensor(self.exported_program, bn.args[4])
        assert running_var is not None

        # args[7] for native_batch_norm, but args[6] for
        # _native_batch_norm_legit_no_training (which doesn't have training
        # as an arg).
        eps = bn.args[-1]

        # Compute the updated weight and bias after fusing the conv or linear op with the batch_norm op.
        fuse_args = (
            input_node_weight,
            input_node_bias,
            running_mean,
            running_var,
            eps,
            bn_weight,
            bn_bias,
        )

        if is_conv:
            is_transpose = input_node.args[6]
            fused_weight, fused_bias = fuse_conv_bn_weights(*fuse_args, is_transpose)
        else:
            # Otherwise, this is a linear node.
            fused_weight, fused_bias = fuse_linear_bn_weights(*fuse_args)

        fused_weight_name = (input_node_weight_name + "_fused_bn").replace(".", "_")
        if input_node_bias_name == "":
            fused_bias_name = (input_node_weight_name + "_bias_fused_bn").replace(
                ".", "_"
            )
        else:
            fused_bias_name = (input_node_bias_name + "_fused_bn").replace(".", "_")

        # Modify the graph by updating the weight and bias of the conv or linear op
        # with the fused weight and bias params, and replacing all the users
        # of getitem(batch_norm) with the conv or linear op.
        with graph.inserting_before(input_node.args[1]):
            fused_op_weight_node = create_constant_placeholder(
                exp_program=self.exported_program,
                graph=graph_module.graph,
                kind=InputKind.PARAMETER,
                name=fused_weight_name,
                data=fused_weight,
            )
            if fused_bias is not None:
                fused_op_bias_node = create_constant_placeholder(
                    exp_program=self.exported_program,
                    graph=graph_module.graph,
                    kind=InputKind.PARAMETER,
                    name=fused_bias_name,
                    data=fused_bias,
                )
            else:
                fused_op_bias_node = None

            # Replace the original weight and bias with the fused batch_norm values.
            args = list(input_node.args)
            args[1] = fused_op_weight_node

            if has_bias_arg:
                # Overwrite original bias with the fused bias.
                args[2] = fused_op_bias_node
            elif fused_op_bias_node is not None:
                # Add the fused bias as a new argument if no bias had originally existed in the input_node.
                args.append(fused_op_bias_node)

            input_node.args = tuple(args)

        # Remove any use of batch_norm from the graph.
        for user in bn.users.copy():
            assert user.target == operator.getitem
            user.replace_all_uses_with(input_node)
            graph.erase_node(user)

        graph.erase_node(bn)
        constant_placeholders_to_delete.update(input_node.args[1:3] + bn.args[1:5])
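
For reference, the arithmetic behind fuse_linear_bn_weights (imported above from torch.nn.utils.fusion): for a Linear layer y = Wx + b followed by BatchNorm with weight gamma, bias beta, running mean mu, running variance var, and epsilon eps, the fused parameters are

    W' = W * (gamma / sqrt(var + eps))    (scaling each output row)
    b' = (b - mu) * (gamma / sqrt(var + eps)) + beta

The snippet below is a minimal numerical check of that identity; the shapes and the perturbed running statistics are illustrative only, not part of this commit.

import torch
from torch.nn.utils.fusion import fuse_linear_bn_weights

linear = torch.nn.Linear(8, 4)
bn = torch.nn.BatchNorm1d(4).eval()  # fusion assumes inference-mode (running) statistics

# Perturb the running statistics so the check is non-trivial.
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)

fused_weight, fused_bias = fuse_linear_bn_weights(
    linear.weight,
    linear.bias,
    bn.running_mean,
    bn.running_var,
    bn.eps,
    bn.weight,
    bn.bias,
)

x = torch.randn(2, 8)
torch.testing.assert_close(
    torch.nn.functional.linear(x, fused_weight, fused_bias),
    bn(linear(x)),
)

The pass applies the same identity at the graph level: it reads these tensors out of the exported program, computes the fused weight and bias once, and serializes them as new constant placeholders, so no batch_norm op survives in the lowered graph.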
