Eliminate squeeze->op->unsqueeze patterns. #12380

Open · wants to merge 1 commit into main
114 changes: 114 additions & 0 deletions backends/cadence/aot/remove_ops.py
@@ -707,6 +707,120 @@ def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
return cast(list[int], permute_node.kwargs["dim"])


@register_cadence_pass(CadencePassAttribute(opt_level=2))
class RemoveSqueezeViewBeforeElementwiseOps(ExportPass):
"""
Looks for subgraphs of the form:
squeeze -> [elementwise ops] -> view
and removes the squeeze node by reshaping the intermediate ops. If the final view
is a corresponding unsqueeze it should also get eliminated by noop view elimination
later. Only handles simple chain of intermediates now.

The pass works on view ops instead of squeeze directly, thus it should be run after
the squeeze/unsqueeze->view lowering.
"""

intermediate_ops: set[EdgeOpOverload] = {
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.cadence.quantize_per_tensor.default,
exir_ops.edge.cadence.dequantize_per_tensor.default,
# Ops that require special handling:
exir_ops.edge.aten.slice_copy.Tensor,
}

def get_squeeze_indices(self, view_node: Node) -> List[int]:
"""
Returns the indices of the input dimensions that are squeezed in the output if
view node is a squeeze. Returns an empty list otherwise.
"""
input_node = cast(Node, get_arg(view_node, "input"))
input_shape = input_node.meta["val"].shape
output_shape = view_node.meta["val"].shape

if len(input_shape) <= len(output_shape):
return []

squeeze_indices = []
out_idx = 0
        for idx, dim in enumerate(input_shape):
            # Guard out_idx so trailing squeezed dims don't index past the end of
            # the output shape.
            if out_idx < len(output_shape) and dim == output_shape[out_idx]:
                out_idx += 1
            else:
                # On a mismatch between the input and output dimensions, the input
                # dimension has to be 1 (i.e., it is squeezed away).
                if dim == 1:
                    squeeze_indices.append(idx)
                else:
                    return []

        # Check that all the output dimensions were consumed.
        if out_idx != len(output_shape):
            return []

return squeeze_indices

def update_slice_copy(self, slice_node: Node, squeeze_indices: List[int]) -> None:
"""
Updates the slice node to account for the squeeze indices removed.
"""
slice_rank = len(slice_node.meta["val"].shape)
slice_dim = cast(int, get_arg(slice_node, "dim"))
if slice_dim < 0:
slice_dim += slice_rank
for squeeze_dim in squeeze_indices:
if slice_dim >= squeeze_dim:
slice_dim += 1
set_arg(slice_node, "dim", slice_dim)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
visited_view_nodes = set()

for view_node in graph_module.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.view_copy.default, sort=True
):
if view_node in visited_view_nodes:
continue

squeeze_indices = self.get_squeeze_indices(view_node)
if not squeeze_indices:
continue

# Only handle simple chains for now
if len(view_node.users) != 1:
continue
node = next(iter(view_node.users))
intermediate_slices = []

# Traverse down from the node until finding another view op.
while node.target != exir_ops.edge.aten.view_copy.default:
# Only handle simple chains for now
if len(node.users) != 1:
break
if node.target not in self.intermediate_ops:
break
if node.target == exir_ops.edge.aten.slice_copy.Tensor:
intermediate_slices.append(node)
node = next(iter(node.users))
            else:
                # The closing view node was found (the while loop did not break).
                # Mark it visited: after the rewrite its input shape changes, so it
                # must not be optimized again.
                visited_view_nodes.add(node)

# Update the intermediate slices.
for slice_node in intermediate_slices:
self.update_slice_copy(slice_node, squeeze_indices)

# Skip the initial view node.
input_node = cast(Node, get_arg(view_node, "input"))
view_node.replace_all_uses_with(input_node)

graph_module.graph.eliminate_dead_code()
graph_module.recompile()

return super().call(graph_module)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveBranchedQuantDequant(ExportPass):
"""
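A minimal standalone sketch of the index arithmetic the pass relies on, in plain Python with no ExecuTorch imports: detect which size-1 input dims a view squeezes away, then shift any slice dim back into the unsqueezed coordinate space. The function names mirror the pass methods, but the shapes here are illustrative assumptions, not real FX node metadata.

from typing import List, Tuple


def get_squeeze_indices(
    input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]
) -> List[int]:
    """Indices of size-1 input dims dropped by a squeeze-like view, else []."""
    if len(input_shape) <= len(output_shape):
        return []
    squeeze_indices: List[int] = []
    out_idx = 0
    for idx, dim in enumerate(input_shape):
        if out_idx < len(output_shape) and dim == output_shape[out_idx]:
            out_idx += 1
        elif dim == 1:
            # Mismatched size-1 input dim: it is squeezed away.
            squeeze_indices.append(idx)
        else:
            # Not a pure squeeze.
            return []
    return squeeze_indices if out_idx == len(output_shape) else []


def remap_slice_dim(slice_dim: int, squeeze_indices: List[int]) -> int:
    """Map a slice dim on the squeezed tensor back to the unsqueezed tensor."""
    for squeeze_dim in sorted(squeeze_indices):
        if slice_dim >= squeeze_dim:
            slice_dim += 1
    return slice_dim


# Mirrors the multiple-squeeze unit test below: (8, 1, 1, 4, 1, 4) -> (8, 4, 4)
# squeezes dims 1, 2, and 4, so a slice on dim 1 of the squeezed tensor maps to
# dim 3 of the original tensor.
assert get_squeeze_indices((8, 1, 1, 4, 1, 4), (8, 4, 4)) == [1, 2, 4]
assert remap_slice_dim(1, [1, 2, 4]) == 3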
98 changes: 98 additions & 0 deletions backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -8,6 +8,7 @@


import unittest
from copy import deepcopy
from typing import cast, List, Tuple

import executorch.backends.cadence.aot.ops_registrations # noqa
@@ -30,6 +31,7 @@
RemoveNopSelectOpPass,
RemoveNopSliceOrViewOpPass,
RemovePermutesAroundElementwiseOps,
RemoveSqueezeViewBeforeElementwiseOps,
RemoveToOpsPass,
RemoveZeroSizedCatArgsPass,
RemoveZeroSizedConstantPadNd,
@@ -569,6 +571,102 @@ def test_remove_permutes_around_elemwise_ops_slice(self) -> None:
self.assertEqual(len(slices), 1)
self.assertEqual(slices[0].args[1], 2)

def test_remove_squeeze_view_before_elemwise_ops(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(8, 1, 4, 4))
squeeze = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(x, [8, 4, 4]),
)
quantize = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(squeeze, 0.12, -4, -128, 127, torch.int8),
)
slice_copy = builder.call_operator(
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(quantize, 1, 0, 2, 1),
)
unsqueeze = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(slice_copy, [8, 1, 2, 4]),
)
builder.output([unsqueeze])
model = builder.get_graph_module()
original = deepcopy(model)

p = RemoveSqueezeViewBeforeElementwiseOps()
transformed = cast(PassResult, p(model)).graph_module

# First view should be eliminated and second view should be trivial.
views = transformed.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.view_copy.default
)
self.assertEqual(len(views), 1)
self.assertEqual(views[0].args[0].meta["val"].shape, views[0].meta["val"].shape)

# Verify that slice dimension was updated correctly.
slices = transformed.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
self.assertEqual(len(slices), 1)
self.assertEqual(slices[0].args[1], 2)

# Verify the output of the model is the same as the original.
sample_input = torch.randn(8, 1, 4, 4)
self.assertTrue(
torch.allclose(
original(sample_input)[0],
transformed(sample_input)[0],
)
)

def test_remove_squeeze_view_before_elemwise_ops_multiple_squeeze(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(8, 1, 1, 4, 1, 4))
squeeze = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(x, [8, 4, 4]),
)
quantize = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(squeeze, 0.12, -4, -128, 127, torch.int8),
)
slice_copy = builder.call_operator(
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(quantize, 1, 0, 2, 1),
)
view_copy = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(slice_copy, [16, 4]),
)
builder.output([view_copy])
model = builder.get_graph_module()
original = deepcopy(model)

p = RemoveSqueezeViewBeforeElementwiseOps()
transformed = cast(PassResult, p(model)).graph_module

# First view should be eliminated.
self.assertEqual(
count_node(transformed, exir_ops.edge.aten.view_copy.default), 1
)

# Verify that slice dimension was updated correctly.
slices = transformed.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
self.assertEqual(len(slices), 1)
self.assertEqual(slices[0].args[1], 3)

# Verify the output of the model is the same as the original.
sample_input = torch.randn(8, 1, 1, 4, 1, 4)
self.assertTrue(
torch.allclose(
original(sample_input)[0],
transformed(sample_input)[0],
)
)

def test_remove_permutes_around_elemwise_ops_mul(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(2, 4, 4, 8))
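As a numerical sanity check of why the rewrite is safe, here is a short plain-PyTorch sketch mirroring the first unit test above, with relu standing in for the quantize op as a generic elementwise intermediate (an illustrative assumption; the pass itself only matches the quant/dequant and slice ops listed in intermediate_ops):

import torch

x = torch.randn(8, 1, 4, 4)

# Original pattern: squeeze-view -> elementwise -> slice(dim=1) -> unsqueeze-view.
before = torch.relu(x.view(8, 4, 4))[:, 0:2, :].view(8, 1, 2, 4)

# After the pass: the leading view is gone and the slice dim is remapped 1 -> 2;
# the trailing view becomes a trivial (no-op) reshape.
after = torch.relu(x)[:, :, 0:2, :].view(8, 1, 2, 4)

assert torch.allclose(before, after)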