|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import operator |
| 8 | + |
| 9 | +import torch |
| 10 | +from executorch.backends.transforms.utils import ( |
| 11 | + create_constant_placeholder, |
| 12 | + delete_constant_placeholder, |
| 13 | +) |
| 14 | + |
| 15 | +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass |
| 16 | + |
| 17 | +from executorch.backends.xnnpack.utils.utils import ( |
| 18 | + get_param_tensor, |
| 19 | + get_tensor_name, |
| 20 | + is_param_node, |
| 21 | +) |
| 22 | +from executorch.exir import ExportedProgram |
| 23 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 24 | +from executorch.exir.pass_base import PassResult |
| 25 | +from torch.export.graph_signature import InputKind |
| 26 | + |
| 27 | +from torch.nn.utils.fusion import fuse_conv_bn_weights, fuse_linear_bn_weights |
| 28 | + |
| 29 | + |
| 30 | +class FuseBatchNormPass(XNNPACKPass): |
| 31 | + """ |
| 32 | + BatchNorm can be implemented using 1x1 Depthwise Convolution. However, doing so will increase |
| 33 | + memory usage since we serialize new weights to represent the convolution. In most cases, |
| 34 | + BatchNorm is used after convolution or linear. The 1x1 depthwise convolution can then be fused |
| 35 | + with the previous convolution. For linear cases, BatchNorm can be folded into the previous linear layer. |
| 36 | + """ |
| 37 | + |
| 38 | + def call(self, graph_module: torch.fx.GraphModule): |
| 39 | + graph = graph_module.graph |
| 40 | + constant_placeholders_to_delete = set() |
| 41 | + for input_node in graph.nodes: |
| 42 | + # We want to discover a chain of conv -> batch_norm or linear -> batch_norm. |
| 43 | + # Only proceed if the current node is a conv or linear, and has a single user/successor. |
| 44 | + is_conv = input_node.target == exir_ops.edge.aten.convolution.default |
| 45 | + is_linear = input_node.target == exir_ops.edge.aten.linear.default |
| 46 | + |
| 47 | + if not (is_conv or is_linear) or len(input_node.users) != 1: |
| 48 | + continue |
| 49 | + |
| 50 | + # The single user of the conv or linear node must be batch_norm. If not, bail. |
| 51 | + bn = list(input_node.users.keys())[0] |
| 52 | + if ( |
| 53 | + bn.target != exir_ops.edge.aten.native_batch_norm.default |
| 54 | + and bn.target |
| 55 | + != exir_ops.edge.aten._native_batch_norm_legit_no_training.default |
| 56 | + ): |
| 57 | + continue |
| 58 | + |
| 59 | + if not self.can_fuse(input_node, bn, self.exported_program): |
| 60 | + continue |
| 61 | + |
| 62 | + self._fuse_ops( |
| 63 | + graph_module, |
| 64 | + graph, |
| 65 | + input_node, |
| 66 | + bn, |
| 67 | + is_conv, |
| 68 | + constant_placeholders_to_delete, |
| 69 | + ) |
| 70 | + |
| 71 | + if len(constant_placeholders_to_delete) > 0: |
| 72 | + graph_module.graph.eliminate_dead_code() |
| 73 | + for node in constant_placeholders_to_delete: |
| 74 | + if (node is not None) and (len(node.users) == 0): |
| 75 | + delete_constant_placeholder(self.exported_program, node) |
| 76 | + |
| 77 | + graph_module.recompile() |
| 78 | + # To regenerate metadata and shape information, retrace module. |
| 79 | + graph_module = super().call(graph_module).graph_module |
| 80 | + |
| 81 | + return PassResult(graph_module, True) |
| 82 | + |
| 83 | + @staticmethod |
| 84 | + def can_fuse( |
| 85 | + input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram |
| 86 | + ) -> bool: |
| 87 | + """ |
| 88 | + Determine whether a BatchNorm node can be fused with the preceding convolution or linear node. |
| 89 | + """ |
| 90 | + |
| 91 | + # All users of the batch_norm node must be getitem ops. |
| 92 | + # batch_norm returns a 3-element tuple. |
| 93 | + # Each user must only access the first element of the tuple. |
| 94 | + if [ |
| 95 | + (user.target == operator.getitem and user.args[1] == 0) for user in bn.users |
| 96 | + ].count(False): |
| 97 | + return False |
| 98 | + |
| 99 | + input_node_weights = input_node.args[1] |
| 100 | + bn_weights = bn.args[1] |
| 101 | + |
| 102 | + # Check that the weights for conv or linear and batch_norm are both params. |
| 103 | + if not isinstance(input_node_weights, torch.fx.Node) or not isinstance( |
| 104 | + bn_weights, torch.fx.Node |
| 105 | + ): |
| 106 | + return False |
| 107 | + |
| 108 | + if [ |
| 109 | + is_param_node(program, node) for node in {input_node_weights, bn_weights} |
| 110 | + ].count(False): |
| 111 | + return False |
| 112 | + |
| 113 | + return True |
| 114 | + |
| 115 | + def _fuse_ops( |
| 116 | + self, |
| 117 | + graph_module: torch.fx.GraphModule, |
| 118 | + graph: torch.fx.Graph, |
| 119 | + input_node: torch.fx.Node, |
| 120 | + bn: torch.fx.Node, |
| 121 | + is_conv: bool, |
| 122 | + constant_placeholders_to_delete: set, |
| 123 | + ) -> None: |
| 124 | + """ |
| 125 | + Fuse a BatchNorm node into the preceding convolution or linear node. |
| 126 | + Update the fused node's weight and bias, rewire users of the BatchNorm output, |
| 127 | + and remove the BatchNorm node. |
| 128 | + """ |
| 129 | + |
| 130 | + if is_conv: |
| 131 | + assert len(input_node.args) == 9 |
| 132 | + has_bias_arg = True |
| 133 | + else: |
| 134 | + # Otherwise, this is a linear node. |
| 135 | + # Linear has 2 or 3 args depending on whether bias is used: (input, weight, bias). |
| 136 | + assert len(input_node.args) in (2, 3) |
| 137 | + has_bias_arg = len(input_node.args) == 3 |
| 138 | + |
| 139 | + # Get the weight and bias parameters from the conv or linear op. |
| 140 | + input_node_weight = get_param_tensor(self.exported_program, input_node.args[1]) |
| 141 | + input_node_weight_name = get_tensor_name( |
| 142 | + self.exported_program, input_node.args[1] |
| 143 | + ) |
| 144 | + assert input_node_weight is not None |
| 145 | + |
| 146 | + if has_bias_arg: |
| 147 | + input_node_bias = get_param_tensor( |
| 148 | + self.exported_program, input_node.args[2] |
| 149 | + ) |
| 150 | + input_node_bias_name = get_tensor_name( |
| 151 | + self.exported_program, input_node.args[2] |
| 152 | + ) |
| 153 | + else: |
| 154 | + input_node_bias = None |
| 155 | + input_node_bias_name = "" |
| 156 | + |
| 157 | + # Get the parameters from the batch_norm op. |
| 158 | + assert ( |
| 159 | + bn.target == exir_ops.edge.aten.native_batch_norm.default |
| 160 | + and len(bn.args) == 8 |
| 161 | + ) or ( |
| 162 | + bn.target == exir_ops.edge.aten._native_batch_norm_legit_no_training.default |
| 163 | + and len(bn.args) == 7 |
| 164 | + ) |
| 165 | + bn_weight = get_param_tensor(self.exported_program, bn.args[1]) |
| 166 | + bn_bias = get_param_tensor(self.exported_program, bn.args[2]) |
| 167 | + |
| 168 | + running_mean = get_param_tensor(self.exported_program, bn.args[3]) |
| 169 | + assert running_mean is not None |
| 170 | + |
| 171 | + running_var = get_param_tensor(self.exported_program, bn.args[4]) |
| 172 | + assert running_var is not None |
| 173 | + |
| 174 | + # args[7] for native_batch_norm, but args[6] for |
| 175 | + # _native_batch_norm_legit_no_training (which doesn't have training |
| 176 | + # as an arg). |
| 177 | + eps = bn.args[-1] |
| 178 | + |
| 179 | + # Compute the updated weight and bias after fusing the conv or linear op with the batch_norm op. |
| 180 | + fuse_args = ( |
| 181 | + input_node_weight, |
| 182 | + input_node_bias, |
| 183 | + running_mean, |
| 184 | + running_var, |
| 185 | + eps, |
| 186 | + bn_weight, |
| 187 | + bn_bias, |
| 188 | + ) |
| 189 | + |
| 190 | + if is_conv: |
| 191 | + is_transpose = input_node.args[6] |
| 192 | + fused_weight, fused_bias = fuse_conv_bn_weights(*fuse_args, is_transpose) |
| 193 | + else: |
| 194 | + # Otherwise, this is a linear node. |
| 195 | + fused_weight, fused_bias = fuse_linear_bn_weights(*fuse_args) |
| 196 | + |
| 197 | + fused_weight_name = (input_node_weight_name + "_fused_bn").replace(".", "_") |
| 198 | + if input_node_bias_name == "": |
| 199 | + fused_bias_name = (input_node_weight_name + "_bias_fused_bn").replace( |
| 200 | + ".", "_" |
| 201 | + ) |
| 202 | + else: |
| 203 | + fused_bias_name = (input_node_bias_name + "_fused_bn").replace(".", "_") |
| 204 | + |
| 205 | + # Modify the graph by updating the weight and bias of the conv or linear op |
| 206 | + # with the fused weight and bias params, and replacing all the users |
| 207 | + # of getitem(batch_norm) with the conv or linear op. |
| 208 | + with graph.inserting_before(input_node.args[1]): |
| 209 | + fused_op_weight_node = create_constant_placeholder( |
| 210 | + exp_program=self.exported_program, |
| 211 | + graph=graph_module.graph, |
| 212 | + kind=InputKind.PARAMETER, |
| 213 | + name=fused_weight_name, |
| 214 | + data=fused_weight, |
| 215 | + ) |
| 216 | + if fused_bias is not None: |
| 217 | + fused_op_bias_node = create_constant_placeholder( |
| 218 | + exp_program=self.exported_program, |
| 219 | + graph=graph_module.graph, |
| 220 | + kind=InputKind.PARAMETER, |
| 221 | + name=fused_bias_name, |
| 222 | + data=fused_bias, |
| 223 | + ) |
| 224 | + else: |
| 225 | + fused_op_bias_node = None |
| 226 | + |
| 227 | + # Replace the original weight and bias with the fused batch_norm values. |
| 228 | + args = list(input_node.args) |
| 229 | + args[1] = fused_op_weight_node |
| 230 | + |
| 231 | + if has_bias_arg: |
| 232 | + # Overwrite original bias with the fused bias. |
| 233 | + args[2] = fused_op_bias_node |
| 234 | + elif fused_op_bias_node is not None: |
| 235 | + # Add the fused bias as a new argument if no bias had originally existed in the input_node. |
| 236 | + args.append(fused_op_bias_node) |
| 237 | + |
| 238 | + input_node.args = tuple(args) |
| 239 | + |
| 240 | + # Remove any use of batch_norm from the graph. |
| 241 | + for user in bn.users.copy(): |
| 242 | + assert user.target == operator.getitem |
| 243 | + user.replace_all_uses_with(input_node) |
| 244 | + graph.erase_node(user) |
| 245 | + |
| 246 | + graph.erase_node(bn) |
| 247 | + constant_placeholders_to_delete.update(input_node.args[1:3] + bn.args[1:5]) |
0 commit comments