diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py
index faee453346c..4721e5a1926 100644
--- a/backends/cadence/aot/remove_ops.py
+++ b/backends/cadence/aot/remove_ops.py
@@ -19,7 +19,7 @@
 
 import logging
 from dataclasses import dataclass, field
-from typing import cast, List, Optional, Sequence
+from typing import cast, List, Optional, Sequence, Set
 
 import torch
 import torch.fx
@@ -707,6 +707,118 @@ def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
         return cast(list[int], permute_node.kwargs["dim"])
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=2))
+class RemoveSqueezeViewBeforeElementwiseOps(ExportPass):
+    """
+    Looks for subgraphs of the form:
+    squeeze -> [elementwise ops] -> view
+    and removes the squeeze node by reshaping the intermediate ops. If the final view
+    is a corresponding unsqueeze, it should also get eliminated by no-op view
+    elimination later. Only handles a simple chain of intermediates for now.
+
+    The pass works on view ops instead of squeeze ops directly, so it should be run
+    after the squeeze/unsqueeze -> view lowering.
+    """
+
+    intermediate_ops: set[EdgeOpOverload] = {
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        exir_ops.edge.cadence.quantize_per_tensor.default,
+        exir_ops.edge.cadence.dequantize_per_tensor.default,
+        # Ops that require special handling:
+        exir_ops.edge.aten.slice_copy.Tensor,
+    }
+
+    def get_squeeze_indices(self, view_node: Node) -> List[int]:
+        """
+        Returns the indices of the input dimensions that are squeezed in the output
+        if the view node is a squeeze. Returns an empty list otherwise.
+        """
+        input_node = cast(Node, get_arg(view_node, "input"))
+        input_shape = input_node.meta["val"].shape
+        output_shape = view_node.meta["val"].shape
+
+        if len(input_shape) <= len(output_shape):
+            return []
+
+        squeeze_indices = []
+        out_idx = 0
+        for idx, dim in enumerate(input_shape):
+            if out_idx >= len(output_shape):
+                return []
+            if dim == output_shape[out_idx]:
+                out_idx += 1
+            else:
+                # If there's a mismatch between the input and output dimensions, the
+                # input dimension has to be 1.
+                if dim == 1:
+                    squeeze_indices.append(idx)
+                else:
+                    return []
+
+        # Check that all the output dimensions are consumed.
+        if out_idx != len(output_shape):
+            return []
+
+        return squeeze_indices
+
+    def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None:
+        if view_node in visited_view_nodes:
+            return
+
+        squeeze_indices = self.get_squeeze_indices(view_node)
+        if not squeeze_indices:
+            return
+
+        # Only handle simple chains for now.
+        if len(view_node.users) != 1:
+            return
+        node = next(iter(view_node.users))
+
+        # Traverse down from the node until finding another view op.
+        intermediate_slices = []
+        while node.target != exir_ops.edge.aten.view_copy.default:
+            # Only handle simple chains for now.
+            if len(node.users) != 1:
+                return
+            if node.target not in self.intermediate_ops:
+                return
+            if node.target == exir_ops.edge.aten.slice_copy.Tensor:
+                intermediate_slices.append(node)
+            node = next(iter(node.users))
+
+        # View node found. We can't optimize this view_node again since its input
+        # shape is now invalid, so add it to the visited set.
+        visited_view_nodes.add(node)
+
+        # Update the intermediate slices.
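+        # The slice dims below were computed against the squeezed (lower-rank) shape.
+        # Since the slices will now consume the original (unsqueezed) input, shift
+        # each dim past every squeezed axis at or before it; e.g. a slice on dim 1 of
+        # [8, 4, 4] becomes a slice on dim 2 of [8, 1, 4, 4] when dim 1 was squeezed.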
+        for slice_node in intermediate_slices:
+            slice_rank = len(slice_node.meta["val"].shape)
+            slice_dim = cast(int, get_arg(slice_node, "dim"))
+            if slice_dim < 0:
+                slice_dim += slice_rank
+            for squeeze_dim in squeeze_indices:
+                if slice_dim >= squeeze_dim:
+                    slice_dim += 1
+            set_arg(slice_node, "dim", slice_dim)
+
+        # Skip the initial view node.
+        input_node = cast(Node, get_arg(view_node, "input"))
+        view_node.replace_all_uses_with(input_node)
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        visited_view_nodes = set()
+        for view_node in graph_module.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.view_copy.default, sort=True
+        ):
+            self.handle_squeeze(view_node, visited_view_nodes)
+
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+
+        return super().call(graph_module)
+
+
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
 class RemoveBranchedQuantDequant(ExportPass):
     """
diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py
index 5fe2848be94..a38416c0ff1 100644
--- a/backends/cadence/aot/tests/test_remove_ops_passes.py
+++ b/backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -8,6 +8,7 @@
 
 import unittest
+from copy import deepcopy
 from typing import cast, List, Tuple
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
@@ -30,6 +31,7 @@
     RemoveNopSelectOpPass,
     RemoveNopSliceOrViewOpPass,
     RemovePermutesAroundElementwiseOps,
+    RemoveSqueezeViewBeforeElementwiseOps,
     RemoveToOpsPass,
     RemoveZeroSizedCatArgsPass,
     RemoveZeroSizedConstantPadNd,
@@ -569,6 +571,102 @@ def test_remove_permutes_around_elemwise_ops_slice(self) -> None:
         self.assertEqual(len(slices), 1)
         self.assertEqual(slices[0].args[1], 2)
 
+    def test_remove_squeeze_view_before_elemwise_ops(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 1, 4, 4))
+        squeeze = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [8, 4, 4]),
+        )
+        quantize = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(squeeze, 0.12, -4, -128, 127, torch.int8),
+        )
+        slice_copy = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(quantize, 1, 0, 2, 1),
+        )
+        unsqueeze = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(slice_copy, [8, 1, 2, 4]),
+        )
+        builder.output([unsqueeze])
+        model = builder.get_graph_module()
+        original = deepcopy(model)
+
+        p = RemoveSqueezeViewBeforeElementwiseOps()
+        transformed = cast(PassResult, p(model)).graph_module
+
+        # First view should be eliminated and second view should be trivial.
+        views = transformed.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.view_copy.default
+        )
+        self.assertEqual(len(views), 1)
+        self.assertEqual(views[0].args[0].meta["val"].shape, views[0].meta["val"].shape)
+
+        # Verify that slice dimension was updated correctly.
+        slices = transformed.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slices), 1)
+        self.assertEqual(slices[0].args[1], 2)
+
+        # Verify the output of the model is the same as the original.
+        sample_input = torch.randn(8, 1, 4, 4)
+        self.assertTrue(
+            torch.allclose(
+                original(sample_input)[0],
+                transformed(sample_input)[0],
+            )
+        )
+
+    def test_remove_squeeze_view_before_elemwise_ops_multiple_squeeze(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 1, 1, 4, 1, 4))
+        squeeze = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [8, 4, 4]),
+        )
+        quantize = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(squeeze, 0.12, -4, -128, 127, torch.int8),
+        )
+        slice_copy = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(quantize, 1, 0, 2, 1),
+        )
+        view_copy = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(slice_copy, [16, 4]),
+        )
+        builder.output([view_copy])
+        model = builder.get_graph_module()
+        original = deepcopy(model)
+
+        p = RemoveSqueezeViewBeforeElementwiseOps()
+        transformed = cast(PassResult, p(model)).graph_module
+
+        # First view should be eliminated.
+        self.assertEqual(
+            count_node(transformed, exir_ops.edge.aten.view_copy.default), 1
+        )
+
+        # Verify that slice dimension was updated correctly.
+        slices = transformed.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slices), 1)
+        self.assertEqual(slices[0].args[1], 3)
+
+        # Verify the output of the model is the same as the original.
+        sample_input = torch.randn(8, 1, 1, 4, 1, 4)
+        self.assertTrue(
+            torch.allclose(
+                original(sample_input)[0],
+                transformed(sample_input)[0],
+            )
+        )
+
     def test_remove_permutes_around_elemwise_ops_mul(self) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(2, 4, 4, 8))