diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 6d4320a3a4c..6b2b61729ed 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -388,6 +388,7 @@ python_unittest(
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:compiler",
         "//executorch/backends/cadence/aot:fuse_ops",
+        "//executorch/backends/cadence/aot:graph_builder",
         "//executorch/backends/cadence/aot:ops_registrations",
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:reorder_ops",
diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py
index bc6153ac8cc..170c81f571e 100644
--- a/backends/cadence/aot/pass_utils.py
+++ b/backends/cadence/aot/pass_utils.py
@@ -144,6 +144,19 @@ def nodes_not_connected_in_gm(
     return True
 
 
+# Returns the graph-order position of the first node with the given target, or -1.
+def get_node_pos(
+    graph_module: torch.fx.GraphModule,
+    target: torch.fx.node.Target,
+) -> int:
+    pos = 0
+    for node in graph_module.graph.nodes:
+        if node.target == target:
+            return pos
+        pos += 1
+    return -1
+
+
 # Returns true if there is no instance of a node with target succ_target
 # positioned immediately after a node with target pred_target in the graph
 def nodes_not_adjacent_in_gm(
diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py
index c02917eb119..3e64a0ecd7c 100644
--- a/backends/cadence/aot/tests/test_reorder_ops_passes.py
+++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py
@@ -11,85 +11,176 @@
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 
-from executorch.backends.cadence.aot.compiler import (
-    export_to_edge,
-    quantize_and_export_to_cadence,
-)
+
 from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
+from executorch.backends.cadence.aot.graph_builder import GraphBuilder
 from executorch.backends.cadence.aot.pass_utils import (
     count_node,
     get_compute_nodes_in_gm,
+    get_node_pos,
     nodes_not_adjacent_in_gm,
     nodes_not_connected_in_gm,
 )
 from executorch.backends.cadence.aot.reorder_ops import (
+    AdvanceQuantizeOpAboveDefChainPass,
     AdvanceQuantizeOpAboveDefInBranchPass,
     PostponeDequantizeOpBelowUseChainPass,
     PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
+    SinkOpsCloserToUsePass,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 
 
 class TestReorderPasses(unittest.TestCase):
     def test_sink_dequantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
-
-            def forward(self, x, y):
-                x1 = self.linear(x)
-                y1 = self.linear(y)
-                x2 = torch.ops.aten.abs(x1)
-                return torch.ops.aten.cat((x2, y1))
-
-        inputs = (torch.randn(32, 6), torch.randn(32, 6))
-        graph_module = (
-            quantize_and_export_to_cadence(M(), inputs).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(32, 6, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(32, 6, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights", torch.randint(-128, 127, (6, 8), dtype=torch.int8)
+        )
+        x_quantized = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 0.02252197265625, 20, -128, 127, torch.int8),
+        )
+        y_quantized = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(y, 0.02181086875498295, -11, -128, 127, torch.int8),
+        )
+        full = builder.call_operator(
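+            # The aten.full ops below appear to materialize quantized_linear's
+            # tensor params (bias, weight zero point, out multiplier, out shift)
+            # as one-element tensors; the literal values mirror a quantized export.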
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
+        )
+        full_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1253324672),
+        )
+        full_2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -3),
+        )
+        full_3 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0.0),
+        )
+        full_4 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
+        )
+        full_5 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1290687488),
+        )
+        full_6 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -3),
+        )
+        full_7 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0.0),
+        )
+        quantized_linear = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(x_quantized, weights, full_3, 20, full_2, full_1, full, 13, None),
         )
+        quantized_linear_1 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(y_quantized, weights, full_7, -11, full_6, full_5, full_4, 8, None),
+        )
+        dequantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear, 0.015294239856302738, 13, -128, 127, torch.int8),
+        )
+        dequantize_per_tensor_1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear_1, 0.014382584020495415, 8, -128, 127, torch.int8),
+        )
+        abs_1 = builder.call_operator(
+            op=exir_ops.edge.aten.abs.default,
+            args=(dequantize_per_tensor,),
+        )
+        cat = builder.call_operator(
+            op=exir_ops.edge.aten.cat.default,
+            args=([abs_1, dequantize_per_tensor_1],),
+        )
+        builder.output(cat)
+        original_graph = builder.get_graph_module()
+        converted_graph = SinkOpsCloserToUsePass()(original_graph).graph_module
+        # Expect the sink pass to move each dequantize down to just above its use (abs / cat)
         self.assertTrue(
             nodes_not_adjacent_in_gm(
-                graph_module,
+                converted_graph,
                 exir_ops.edge.aten.abs.default,
                 exir_ops.edge.aten.cat.default,
             ),
         )
         self.assertTrue(
             nodes_not_adjacent_in_gm(
-                graph_module,
+                converted_graph,
                 exir_ops.edge.cadence.dequantize_per_tensor.default,
                 exir_ops.edge.cadence.dequantize_per_tensor.default,
             ),
         )
 
     def test_advance_branched_quantize(self):
-        class ReorderOpsBranch(torch.nn.Module):
-            def forward(self, x):
-                x = x.view((32, 6))
-                x1 = torch.slice_copy(x, dim=0, start=0, end=6, step=1)
-                x1 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x1, 0.1, 10, 0, 255, torch.uint8
-                )
-                x2 = torch.slice_copy(x, dim=0, start=6, end=12, step=1)
-                x2 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x2, 0.1, 10, 0, 255, torch.uint8
-                )
-                x3 = torch.slice_copy(x, dim=0, start=12, end=18, step=1)
-                x3 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x3, 0.1, 10, 0, 255, torch.uint8
-                )
-                x4 = torch.slice_copy(x, dim=0, start=18, end=24, step=1)
-                x4 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x4, 0.2, 4, 0, 255, torch.uint8
-                )
-                return (x1, x2, x3, x4)
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(64, 3, dtype=torch.float32))
+        view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [32, 6]),
+        )
+        aten_slice_copy_tensor = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 0, 6),
+        )
+        quantized_decomposed_quantize_per_tensor_default = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
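+            # quantized_decomposed.quantize_per_tensor args:
+            # (input, scale, zero_point, quant_min, quant_max, dtype)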
+            args=(aten_slice_copy_tensor, 0.1, 10, 0, 255, torch.uint8),
+        )
+
+        aten_slice_copy_tensor_1 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 6, 12),
+        )
+        quantized_decomposed_quantize_per_tensor_default_1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor_1, 0.1, 10, 0, 255, torch.uint8),
+        )
+
+        aten_slice_copy_tensor_2 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 12, 18),
+        )
+        quantized_decomposed_quantize_per_tensor_default_2 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor_2, 0.1, 10, 0, 255, torch.uint8),
+        )
 
-        model = ReorderOpsBranch()
-        X = torch.randn(64, 3)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
+        aten_slice_copy_tensor_3 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 18, 24),
+        )
+        quantized_decomposed_quantize_per_tensor_default_3 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor_3, 0.2, 4, 0, 255, torch.uint8),
+        )
+        builder.output(
+            [
+                quantized_decomposed_quantize_per_tensor_default,
+                quantized_decomposed_quantize_per_tensor_default_1,
+                quantized_decomposed_quantize_per_tensor_default_2,
+                quantized_decomposed_quantize_per_tensor_default_3,
+            ]
+        )
+        original_graph = builder.get_graph_module()
         graph_module = AdvanceQuantizeOpAboveDefInBranchPass()(
-            graph_module
+            original_graph
         ).graph_module
         graph_module.graph.eliminate_dead_code()
         nodes = get_compute_nodes_in_gm(graph_module)
@@ -135,104 +226,196 @@ def forward(self, x):
     @torch.no_grad()
     def test_advance_quantize(self):
-        class ReorderOpsChain(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
-
-            def forward(self, x):
-                x = x.permute([1, 0, 3, 2])
-                x = self.linear(x)
-                return x
-
-        model = ReorderOpsChain()
-        X = torch.randn(16, 1, 6, 32)
-
-        graph_module = (
-            quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights", torch.randint(-128, 127, (32, 32), dtype=torch.int8)
         )
-        # Assert that the quant node is no longer the successor of
-        # permute node.
-        self.assertTrue(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.aten.permute_copy.default,
-                exir_ops.edge.cadence.quantize_per_tensor.default,
-            ),
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
         )
-        # Assert that permute node is the successor of quant node
-        self.assertFalse(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.cadence.quantize_per_tensor.default,
-                exir_ops.edge.aten.permute_copy.default,
+        full_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1525501056),
+        )
+        full_2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 2),
+        )
+        full_3 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([12], 0.0),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(x, [1, 0, 3, 2]),
+        )
+        quantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(permute, 0.029049983248114586, -1, -128, 127, torch.int8),
+        )
+        quantized_linear = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quantize_per_tensor,
+                weights,
+                full_3,
+                -1,
+                full_2,
+                full_1,
+                full,
+                -7,
+                None,
             ),
         )
-
-    def test_postpone_dequantize(self):
-        class ReorderOpsChain(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
-
-            def forward(self, x):
-                x = self.linear(x)
-                x = x.permute([1, 0, 3, 2])
-                return x
-
-        model = ReorderOpsChain()
-        X = torch.randn(1, 16, 32, 6)
-
-        graph_module = (
-            quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module
+        dequantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear, 0.01627226173877716, -7, -128, 127, torch.int8),
        )
-        # Assert that the dequant node is no longer the predecessor of the permute node
+        builder.output(dequantize_per_tensor)
+        original_graph = builder.get_graph_module()
+        converted_graph = AdvanceQuantizeOpAboveDefInBranchPass()(
+            original_graph
+        ).graph_module
+        converted_graph = AdvanceQuantizeOpAboveDefChainPass()(
+            converted_graph
+        ).graph_module
+        # Assert that the quant node now precedes the permute node.
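+        # get_node_pos (added to pass_utils in this diff) returns a node's index
+        # in graph order; a smaller index means the node appears earlier.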
         self.assertTrue(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.cadence.dequantize_per_tensor.default,
-                exir_ops.edge.aten.permute_copy.default,
-            ),
+            get_node_pos(
+                converted_graph, exir_ops.edge.cadence.quantize_per_tensor.default
+            )
+            < get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)
         )
-        # Assert that dequant node is the successor of permute node
-        self.assertFalse(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.aten.permute_copy.default,
-                exir_ops.edge.cadence.dequantize_per_tensor.default,
+
+    def test_postpone_dequantize1(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights", torch.randint(-128, 127, (6, 6), dtype=torch.int8)
+        )
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
+        )
+        full_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1461148032),
+        )
+        full_2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -4),
+        )
+        full_3 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([12], 0.0),
+        )
+        quantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 0.029049983248114586, -1, -128, 127, torch.int8),
+        )
+        quantized_linear = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quantize_per_tensor,
+                weights,
+                full_3,
+                -8,
+                full_2,
+                full_1,
+                full,
+                0,
+                None,
             ),
         )
+        dequantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear, 0.01627226173877716, -7, -128, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(dequantize_per_tensor, [1, 0, 3, 2]),
+        )
+        builder.output(permute)
+        original_graph = builder.get_graph_module()
+        converted_graph = PostponeDequantizeOpBelowUseChainPass()(
+            original_graph
+        ).graph_module
+        # Assert that dequant node is now the successor of the permute node.
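+        # (The postponed dequantize now consumes the permute's output, so the
+        # permute comes first in graph order, per get_node_pos.)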
+        self.assertTrue(
+            get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)
+            < get_node_pos(
+                converted_graph,
+                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            )
+        )
 
     def test_postpone_dequantize_branched(self):
-        class ReorderOpsBranch(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(3, 12, bias=False)
-
-            def forward(self, x):
-                x0 = exir_ops.edge.quantized_decomposed.dequantize_per_tensor(
-                    x, 0.1, 10, 0, 255, torch.uint8
-                )
-                x0 = torch.squeeze(x0, 0)
-                x1 = torch.slice_copy(x0, dim=0, start=0, end=6, step=1)
-                x1 = self.linear(x1)
-
-                x2 = torch.slice_copy(x0, dim=0, start=6, end=12, step=1)
-                x2 = self.linear(x2)
+        builder = GraphBuilder()
+        x = builder.placeholder(
+            "x", torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8)
+        )
+        p_linear_weight = builder.placeholder(
+            "weights", torch.randint(-128, 127, (3, 3), dtype=torch.int8)
+        )
+        quantized_decomposed_dequantize_per_tensor_default = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, 0.1, 10, 0, 255, torch.uint8),
+        )
+        aten_squeeze_copy_dims = builder.call_operator(
+            op=exir_ops.edge.aten.squeeze_copy.dims,
+            args=(quantized_decomposed_dequantize_per_tensor_default, [0]),
+        )
 
-                x3 = torch.slice_copy(x0, dim=0, start=12, end=18, step=1)
-                x3 = self.linear(x3)
+        aten_slice_copy_tensor = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(aten_squeeze_copy_dims, 0, 0, 6),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(p_linear_weight, [1, 0]),
+        )
+        aten_mm_default = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(aten_slice_copy_tensor, aten_permute_copy_default),
+        )
 
-                return (x1, x2, x3)
+        aten_slice_copy_tensor_1 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(aten_squeeze_copy_dims, 0, 6, 12),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(p_linear_weight, [1, 0]),
+        )
+        aten_mm_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(aten_slice_copy_tensor_1, aten_permute_copy_default_1),
+        )
 
-        model = ReorderOpsBranch()
-        X = torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
+        aten_slice_copy_tensor_2 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(aten_squeeze_copy_dims, 0, 12, 18),
+        )
+        aten_permute_copy_default_2 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(p_linear_weight, [1, 0]),
+        )
+        aten_mm_default_2 = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(aten_slice_copy_tensor_2, aten_permute_copy_default_2),
+        )
+        builder.output([aten_mm_default, aten_mm_default_1, aten_mm_default_2])
+        original_graph = builder.get_graph_module()
         graph_module = PostponeDequantizeOpBelowUseChainPass()(
-            graph_module
+            original_graph
         ).graph_module
         graph_module.graph.eliminate_dead_code()
-        # Asset that the dequant node was split into 4, one per branch
         self.assertEqual(
             count_node(
@@ -261,31 +444,35 @@ def forward(self, x):
 
     # 4d -> permute -> 4d -> view -> 3d
     def test_permute3_view4_chains(self):
-        class PermuteViewChain(torch.nn.Module):
-            def forward(self, x):
-                # x is [3, 1, 768]
-                x = x.view((3, 12, 64))
-                # x is [3, 12, 64]
-                x = x.permute([1, 0, 2])
-                # x is [12, 3, 64]
-                x = x.view((1, 12, 3, 64))
-                # x is [1, 12, 3, 64]
-                x = x.permute([0, 1, 3, 2])
-                # x is [1, 12, 64, 3]
-                return x
-
-        model = PermuteViewChain()
-        X = torch.randn(3, 1, 768)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 768))
+        aten_view_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [3, 12, 64]),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default, [1, 0, 2]),
+        )
+        aten_view_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(aten_permute_copy_default, [1, 12, 3, 64]),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default_1, [0, 1, 3, 2]),
+        )
+        builder.output(
+            aten_permute_copy_default_1,
+        )
+        original_graph = builder.get_graph_module()
         # Performing transform
-        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
-            graph_module
+        converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
+            original_graph
         ).graph_module
-        graph_module.graph.eliminate_dead_code()
-
+        converted_graph.graph.eliminate_dead_code()
         # Assert the order becomes view, view, permute, permute
-        nodes = get_compute_nodes_in_gm(graph_module)
+        nodes = get_compute_nodes_in_gm(converted_graph)
         self.assertEqual(len(nodes), 4)
         self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
         self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
@@ -294,31 +481,36 @@ def forward(self, x):
 
     # 3d -> permute -> 3d -> view -> 4d
     def test_permute4_view3_chains(self):
-        class PermuteViewChain(torch.nn.Module):
-            def forward(self, x):
-                # x is [3, 1, 768]
-                x = x.view((1, 3, 12, 64))
-                # x is [1, 3, 12, 64]
-                x = x.permute([3, 1, 0, 2])
-                # x is [64, 3, 1, 12]
-                x = x.view((64, 3, 12))
-                # x is [64, 3, 12]
-                x = x.permute([2, 1, 0])
-                # x is [12, 3, 64]
-                return x
-
-        model = PermuteViewChain()
-        X = torch.randn(3, 1, 768)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 768))
+        aten_view_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [1, 3, 12, 64]),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default, [3, 1, 0, 2]),
+        )
+        aten_view_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(aten_permute_copy_default, [64, 3, 12]),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default_1, [2, 1, 0]),
+        )
+        builder.output(
+            aten_permute_copy_default_1,
+        )
+        original_graph = builder.get_graph_module()
         # Performing transform
-        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
-            graph_module
+        converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
+            original_graph
        ).graph_module
-        graph_module.graph.eliminate_dead_code()
+        converted_graph.graph.eliminate_dead_code()
         # Assert the order becomes view, view, permute, permute
-        nodes = get_compute_nodes_in_gm(graph_module)
+        nodes = get_compute_nodes_in_gm(converted_graph)
         self.assertEqual(len(nodes), 4)
         self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
         self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
@@ -329,31 +521,36 @@ def forward(self, x):
 
     # permute->4d->view->3d where the view not only removes the dimension whose
     # size is 1 (this is ok), but also changes the size of the dimensions (not ok).
     def test_permute_view_chains_neg(self):
-        class PermuteViewChain(torch.nn.Module):
-            def forward(self, x):
-                # x is [3, 1, 768]
-                x = x.view((1, 3, 12, 64))
-                # x is [1, 3, 12, 64]
-                x = x.permute([3, 1, 0, 2])
-                # x is [64, 3, 1, 12]
-                x = x.view((64, 6, 6))
-                # x is [64, 6, 6]
-                x = x.permute([2, 1, 0])
-                # x is [6, 6, 64]
-                return x
-
-        model = PermuteViewChain()
-        X = torch.randn(3, 1, 768)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 768))
+        aten_view_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [1, 3, 12, 64]),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default, [3, 1, 0, 2]),
+        )
+        aten_view_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(aten_permute_copy_default, [64, 6, 6]),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default_1, [2, 1, 0]),
+        )
+        builder.output(
+            aten_permute_copy_default_1,
+        )
+        original_graph = builder.get_graph_module()
         # Performing transform (nothing should happen)
-        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
-            graph_module
+        converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
+            original_graph
         ).graph_module
-        graph_module.graph.eliminate_dead_code()
+        converted_graph.graph.eliminate_dead_code()
         # Assert the order is still view, permute, view, permute
-        nodes = get_compute_nodes_in_gm(graph_module)
+        nodes = get_compute_nodes_in_gm(converted_graph)
         self.assertEqual(len(nodes), 4)
         self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
         self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy)