
Commit 0984e76

abeakkas authored and facebook-github-bot committed
Eliminate squeeze->op->unsqueeze patterns. (#12380)
Summary: Adds a pre-partitioner pass that eliminates squeeze->[elementwise ops + slice]->unsqueeze patterns in the graph. For context, this pattern is seen in the MicroGestures model. Also moves PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView to run after quant/dequant fusion, since that ordering enables more patterns to be fused. Differential Revision: D78104324
1 parent 3afd18d commit 0984e76
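
For intuition, the rewrite is sound because a slice on the squeezed tensor equals a slice on the unsqueezed tensor with the dimension index shifted past the retained size-1 dim, and elementwise ops commute with reshapes. A minimal sketch in plain PyTorch, with shapes borrowed from the new unit test (reshape stands in for the lowered view_copy ops; this is an illustration, not the pass itself):

import torch

x = torch.randn(8, 1, 4, 4)

# Original chain: squeeze-like view -> slice on dim 1 -> unsqueeze-like view.
chained = x.reshape(8, 4, 4)[:, 0:2, :].reshape(8, 1, 2, 4)

# What the pass produces: the same slice applied directly on dim 2, i.e. the
# slice dim is bumped by one because the size-1 dim at index 1 is kept.
direct = x[:, :, 0:2, :]

assert torch.equal(chained, direct)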

File tree

2 files changed: +160 -0 lines


backends/cadence/aot/remove_ops.py

Lines changed: 111 additions & 0 deletions
@@ -707,6 +707,117 @@ def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
         return cast(list[int], permute_node.kwargs["dim"])
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=2))
+class RemoveSqueezeUnsqueezeAroundElementwiseOps(ExportPass):
+    """
+    Looks for subgraphs of the form:
+        squeeze -> [op] -> unsqueeze
+    and removes the squeeze and unsqueeze nodes by reshaping the intermediate ops.
+    Only handles a simple chain of intermediate ops for now.
+
+    The pass works on view ops instead of squeeze and unsqueeze directly, thus it
+    should be run after the squeeze/unsqueeze->view lowering.
+    """
+
+    intermediate_ops: set[EdgeOpOverload] = {
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        exir_ops.edge.cadence.quantize_per_tensor.default,
+        exir_ops.edge.cadence.dequantize_per_tensor.default,
+        # Ops that require special handling:
+        exir_ops.edge.aten.slice_copy.Tensor,
+    }
+
+    def find_unsqueeze_dim(self, view_node: Node) -> Optional[int]:
+        """
+        Return the unsqueeze dim if the given view_copy op unsqueezes the input
+        tensor; otherwise return None.
+        """
+        input_node = cast(Node, get_arg(view_node, 0, "input"))
+        input_shape = input_node.meta["val"].shape
+        output_shape = view_node.meta["val"].shape
+        if len(output_shape) != len(input_shape) + 1:
+            return None
+        for dim in range(len(output_shape)):
+            if output_shape == input_shape[:dim] + (1,) + input_shape[dim:]:
+                return dim
+        return None
+
+    def find_ancestor_squeeze(self, node: Node, squeeze_dim: int) -> Optional[Node]:
+        """
+        Traverse up from the given node until finding a squeeze node with the
+        given squeeze_dim. If no such node is found, return None.
+        """
+        while True:
+            # Only handle simple chains for now
+            if len(node.users) != 1:
+                return None
+            if node.target in self.intermediate_ops:
+                node = cast(Node, get_arg(node, 0, "input"))
+            elif node.target == exir_ops.edge.aten.view_copy.default:
+                input_node = cast(Node, get_arg(node, 0, "input"))
+                input_shape = input_node.meta["val"].shape
+                output_shape = node.meta["val"].shape
+                # Check if the node is a squeeze op.
+                if (
+                    len(input_shape) != len(output_shape) + 1
+                    or input_shape
+                    != output_shape[:squeeze_dim] + (1,) + output_shape[squeeze_dim:]
+                ):
+                    return None
+                return node
+            else:
+                return None
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        changed = False
+
+        # Traverse the graph looking for unsqueeze-like view ops.
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.view_copy.default
+        ):
+            unsqueeze_dim = self.find_unsqueeze_dim(node)
+            if unsqueeze_dim is None:
+                continue
+
+            input_node = cast(Node, get_arg(node, 0, "input"))
+            squeeze_node = self.find_ancestor_squeeze(input_node, unsqueeze_dim)
+            if squeeze_node is None:
+                continue
+
+            # Chain is found. Remove the view ops and update the intermediate
+            # ops while traversing the chain.
+            assert len(squeeze_node.users) == 1
+            node = next(iter(squeeze_node.users))
+
+            # Skip first view_copy.
+            squeeze_node.replace_all_uses_with(
+                cast(Node, get_arg(squeeze_node, 0, "input"))
+            )
+
+            # Go down the chain and update the intermediate ops if needed.
+            while node.target != exir_ops.edge.aten.view_copy.default:
+                if node.target == exir_ops.edge.aten.slice_copy.Tensor:
+                    slice_dim = cast(int, get_arg(node, 1, "dim", default=0))
+                    if slice_dim < 0:
+                        slice_dim += len(node.meta["val"].shape)
+                    if slice_dim >= unsqueeze_dim:
+                        set_arg(node, 1, "dim", slice_dim + 1)
+                assert len(node.users) == 1
+                node = next(iter(node.users))
+
+            # Skip final view_copy.
+            node.replace_all_uses_with(cast(Node, get_arg(node, 0, "input")))
+
+            changed = True
+
+        if changed:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+
+        return PassResult(graph_module, changed)
+
+
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
 class RemoveBranchedQuantDequant(ExportPass):
     """

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 49 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 
 import unittest
+from copy import deepcopy
 from typing import cast, List, Tuple
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
@@ -30,6 +31,7 @@
     RemoveNopSelectOpPass,
     RemoveNopSliceOrViewOpPass,
     RemovePermutesAroundElementwiseOps,
+    RemoveSqueezeUnsqueezeAroundElementwiseOps,
     RemoveToOpsPass,
     RemoveZeroSizedCatArgsPass,
     RemoveZeroSizedConstantPadNd,
@@ -569,6 +571,53 @@ def test_remove_permutes_around_elemwise_ops_slice(self) -> None:
         self.assertEqual(len(slices), 1)
         self.assertEqual(slices[0].args[1], 2)
 
+    def test_remove_squeeze_unsqueeze_around_elemwise_ops(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 1, 4, 4))
+        squeeze = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [8, 4, 4]),
+        )
+        quantize = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(squeeze, 0.12, -4, -128, 127, torch.int8),
+        )
+        slice_copy = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(quantize, 1, 0, 2, 1),
+        )
+        unsqueeze = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(slice_copy, [8, 1, 2, 4]),
+        )
+        builder.output([unsqueeze])
+        model = builder.get_graph_module()
+        original = deepcopy(model)
+
+        p = RemoveSqueezeUnsqueezeAroundElementwiseOps()
+        transformed = cast(PassResult, p(model)).graph_module
+
+        # No views should remain.
+        self.assertEqual(
+            count_node(transformed, exir_ops.edge.aten.view_copy.default), 0
+        )
+
+        # Verify that slice dimension was updated correctly.
+        slices = transformed.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slices), 1)
+        self.assertEqual(slices[0].args[1], 2)
+
+        # Verify the output of the model is the same as the original.
+        sample_input = torch.randn(8, 1, 4, 4)
+        self.assertTrue(
+            torch.allclose(
+                original(sample_input)[0],
+                transformed(sample_input)[0],
+            )
+        )
+
     def test_remove_permutes_around_elemwise_ops_mul(self) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(2, 4, 4, 8))
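
The final allclose check passes because per-tensor quantization is elementwise and therefore commutes with the removed views. A quick standalone check of that property, using the test's scale and zero point with torch.quantize_per_tensor for illustration (the pass itself operates on the decomposed edge ops):

import torch

x = torch.randn(8, 1, 4, 4)
scale, zero_point = 0.12, -4  # same quantization params as the test

# Quantize then reshape equals reshape then quantize.
q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
q_then_view = q.int_repr().reshape(8, 4, 4)
view_then_q = torch.quantize_per_tensor(
    x.reshape(8, 4, 4), scale, zero_point, torch.qint8
).int_repr()

assert torch.equal(q_then_view, view_then_q)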
