Use GraphBuilder in reorder unit tests. #11103

Merged · 1 commit · May 27, 2025
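The rewritten tests below all follow one pattern: build the input graph directly with GraphBuilder instead of exporting an nn.Module, run the reorder pass under test, and assert on the resulting graph. A minimal sketch of that pattern, using a hypothetical single-op graph and SinkOpsCloserToUsePass as a stand-in for any of the passes exercised below:

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.reorder_ops import SinkOpsCloserToUsePass
from executorch.exir.dialects._ops import ops as exir_ops

# Build the test graph directly; no quantizer or export pipeline is involved.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(32, 6, dtype=torch.float32))
x_abs = builder.call_operator(op=exir_ops.edge.aten.abs.default, args=(x,))
builder.output([x_abs])
original_graph = builder.get_graph_module()

# Run the pass under test and inspect the rewritten graph module.
converted_graph = SinkOpsCloserToUsePass()(original_graph).graph_module
converted_graph.graph.eliminate_dead_code()

Building graphs this way keeps each test hermetic: the pass sees exactly the nodes written in the test, rather than whatever the quantizer and export pipeline happen to emit.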
1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
@@ -388,6 +388,7 @@ python_unittest(
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:fuse_ops",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:reorder_ops",
13 changes: 13 additions & 0 deletions backends/cadence/aot/pass_utils.py
@@ -144,6 +144,19 @@ def nodes_not_connected_in_gm(
return True


# Returns the position of the first entry of a node of a given kind in the graph.
def get_node_pos(
graph_module: torch.fx.GraphModule,
target: torch.fx.Node,
) -> int:
pos = 0
for node in graph_module.graph.nodes:
if node.target == target:
return pos
pos += 1
return -1
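A small usage sketch of get_node_pos (toy two-op graph with hypothetical shapes): it backs the ordering assertions in the rewritten tests, and reports -1 for a target that never occurs.

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import get_node_pos
from executorch.exir.dialects._ops import ops as exir_ops

builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(2, 3))
y = builder.call_operator(op=exir_ops.edge.aten.abs.default, args=(x,))
z = builder.call_operator(
    op=exir_ops.edge.aten.permute_copy.default, args=(y, [1, 0])
)
builder.output([z])
gm = builder.get_graph_module()

# abs precedes permute_copy in the graph's node order.
assert get_node_pos(gm, exir_ops.edge.aten.abs.default) < get_node_pos(
    gm, exir_ops.edge.aten.permute_copy.default
)
# A target absent from the graph yields -1.
assert get_node_pos(gm, exir_ops.edge.aten.cat.default) == -1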


# Returns true if there is no instance of a node with target succ_target
# positioned immediately after a node with target pred_target in the graph
def nodes_not_adjacent_in_gm(
566 changes: 377 additions & 189 deletions backends/cadence/aot/tests/test_reorder_ops_passes.py
@@ -11,85 +11,171 @@

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch

from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import (
count_node,
get_compute_nodes_in_gm,
get_node_pos,
nodes_not_adjacent_in_gm,
nodes_not_connected_in_gm,
)
from executorch.backends.cadence.aot.reorder_ops import (
AdvanceQuantizeOpAboveDefChainPass,
AdvanceQuantizeOpAboveDefInBranchPass,
PostponeDequantizeOpBelowUseChainPass,
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
SinkOpsCloserToUsePass,
)
from executorch.exir.dialects._ops import ops as exir_ops


class TestReorderPasses(unittest.TestCase):
def test_sink_dequantize(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(32, 6, dtype=torch.float32))
y = builder.placeholder("y", torch.randn(32, 6, dtype=torch.float32))
weights = builder.placeholder(
"weights", torch.randint(-128, 127, (6, 8), dtype=torch.int8)
)
x_quantized = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(x, 0.02252197265625, 20, -128, 127, torch.int8),
)
y_quantized = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(y, 0.02181086875498295, -11, -128, 127, torch.int8),
)
full = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], -7),
)
full_1 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], 1253324672),
)
full_2 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], -3),
)
full_3 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], 0.0),
)
full_4 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], -7),
)
full_5 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], 1290687488),
)
full_6 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], -3),
)
full_7 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], 0.0),
)
quantized_linear = builder.call_operator(
op=exir_ops.edge.cadence.quantized_linear.default,
args=(x_quantized, weights, full_3, 20, full_2, full_1, full, 13, None),
)
quantized_linear_1 = builder.call_operator(
op=exir_ops.edge.cadence.quantized_linear.default,
args=(y_quantized, weights, full_7, -11, full_6, full_5, full_4, 8, None),
)
dequantize_per_tensor = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(quantized_linear, 0.015294239856302738, 13, -128, 127, torch.int8),
)
dequantize_per_tensor_1 = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(quantized_linear_1, 0.014382584020495415, 8, -128, 127, torch.int8),
)
abs_1 = builder.call_operator(
op=exir_ops.edge.aten.abs.default,
args=(dequantize_per_tensor,),
)
cat = builder.call_operator(
op=exir_ops.edge.aten.cat.default,
args=([abs_1, dequantize_per_tensor_1],),
)
builder.output(cat)
original_graph = builder.get_graph_module()
converted_graph = SinkOpsCloserToUsePass()(original_graph).graph_module

# Expect SinkOpsCloserToUsePass to move dequantize(y) from above the abs down to
# just before the cat, so abs/cat and the two dequantize nodes are no longer adjacent
self.assertTrue(
nodes_not_adjacent_in_gm(
converted_graph,
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.cat.default,
),
)
self.assertTrue(
nodes_not_adjacent_in_gm(
converted_graph,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
),
)

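For intuition about the two assertions above: nodes_not_adjacent_in_gm flags only immediate succession in the graph's node order. A minimal sketch on a hypothetical two-op chain:

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import nodes_not_adjacent_in_gm
from executorch.exir.dialects._ops import ops as exir_ops

builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4))
a = builder.call_operator(op=exir_ops.edge.aten.abs.default, args=(x,))
v = builder.call_operator(
    op=exir_ops.edge.aten.view_copy.default, args=(a, [2, 2])
)
builder.output([v])
gm = builder.get_graph_module()

# abs is immediately followed by view_copy, so the check returns False ...
assert not nodes_not_adjacent_in_gm(
    gm, exir_ops.edge.aten.abs.default, exir_ops.edge.aten.view_copy.default
)
# ... while view_copy is never immediately followed by abs, so it returns True.
assert nodes_not_adjacent_in_gm(
    gm, exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.abs.default
)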
def test_advance_branched_quantize(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(64, 3, dtype=torch.float32))
view = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(x, [32, 6]),
)
aten_slice_copy_tensor = builder.call_operator(
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(view, 0, 0, 6),
)
quantized_decomposed_quantize_per_tensor_default = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(aten_slice_copy_tensor, 0.1, 10, 0, 255, torch.uint8),
)

aten_slice_copy_tensor_1 = builder.call_operator(
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(view, 0, 6, 12),
)
quantized_decomposed_quantize_per_tensor_default_1 = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(aten_slice_copy_tensor_1, 0.1, 10, 0, 255, torch.uint8),
)

aten_slice_copy_tensor_2 = builder.call_operator(
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(view, 0, 12, 18),
)
quantized_decomposed_quantize_per_tensor_default_2 = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(aten_slice_copy_tensor_2, 0.1, 10, 0, 255, torch.uint8),
)

aten_slice_copy_tensor_3 = builder.call_operator(
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(view, 0, 18, 24),
)
quantized_decomposed_quantize_per_tensor_default_3 = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(aten_slice_copy_tensor_3, 0.2, 4, 0, 255, torch.uint8),
)
builder.output(
[
quantized_decomposed_quantize_per_tensor_default,
quantized_decomposed_quantize_per_tensor_default_1,
quantized_decomposed_quantize_per_tensor_default_2,
quantized_decomposed_quantize_per_tensor_default_3,
]
)
original_graph = builder.get_graph_module()
graph_module = AdvanceQuantizeOpAboveDefInBranchPass()(
original_graph
).graph_module
graph_module.graph.eliminate_dead_code()
nodes = get_compute_nodes_in_gm(graph_module)
@@ -135,104 +221,192 @@ def forward(self, x):

@torch.no_grad()
def test_advance_quantize(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32))
weights = builder.placeholder(
"weights", torch.randint(-128, 127, (32, 32), dtype=torch.int8)
)
full = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], -7),
)
full_1 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], 1525501056),
)
full_2 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], 2),
)
full_3 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([12], 0.0),
)
permute = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(x, [1, 0, 3, 2]),
)
quantize_per_tensor = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(permute, 0.029049983248114586, -1, -128, 127, torch.int8),
)
quantized_linear = builder.call_operator(
op=exir_ops.edge.cadence.quantized_linear.default,
args=(
quantize_per_tensor,
weights,
full_3,
-1,
full_2,
full_1,
full,
-7,
None,
),
)

dequantize_per_tensor = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(quantized_linear, 0.01627226173877716, -7, -128, 127, torch.int8),
)
builder.output(dequantize_per_tensor)
original_graph = builder.get_graph_module()
converted_graph = AdvanceQuantizeOpAboveDefInBranchPass()(
original_graph
).graph_module
converted_graph = AdvanceQuantizeOpAboveDefChainPass()(
original_graph
).graph_module
# Assert that permute node is now the successor of the quant node.
self.assertTrue(
get_node_pos(
converted_graph, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
)
< get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)
)

def test_postpone_dequantize1(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32))
weights = builder.placeholder(
"weights", torch.randint(-128, 127, (6, 6), dtype=torch.int8)
)
full = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], -7),
)
full_1 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], 1461148032),
)
full_2 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([1], -4),
)
full_3 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([12], 0.0),
)
quantize_per_tensor = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(x, 0.029049983248114586, -1, -128, 127, torch.int8),
)
quantized_linear = builder.call_operator(
op=exir_ops.edge.cadence.quantized_linear.default,
args=(
quantize_per_tensor,
weights,
full_3,
-8,
full_2,
full_1,
full,
0,
None,
),
)
dequantize_per_tensor = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(quantized_linear, 0.01627226173877716, -7, -128, 127, torch.int8),
)
permute = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(dequantize_per_tensor, [1, 0, 3, 2]),
)
builder.output(permute)
original_graph = builder.get_graph_module()
converted_graph = PostponeDequantizeOpBelowUseChainPass()(
original_graph
).graph_module
# Assert that dequant node is now the successor of the permute node.
self.assertTrue(
get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)
< get_node_pos(
converted_graph,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
)
)

def test_postpone_dequantize_branched(self):
builder = GraphBuilder()
x = builder.placeholder(
"x", torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8)
)
p_linear_weight = builder.placeholder(
"weights", torch.randint(-128, 127, (3, 3), dtype=torch.int8)
)
quantized_decomposed_dequantize_per_tensor_default = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(x, 0.1, 10, 0, 255, torch.uint8),
)
aten_squeeze_copy_dims = builder.call_operator(
op=exir_ops.edge.aten.squeeze_copy.dims,
args=(quantized_decomposed_dequantize_per_tensor_default, [0]),
)

aten_slice_copy_tensor = builder.call_operator(
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(aten_squeeze_copy_dims, 0, 0, 6),
)
aten_permute_copy_default = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(p_linear_weight, [1, 0]),
)
aten_mm_default = builder.call_operator(
op=exir_ops.edge.aten.mm.default,
args=(aten_slice_copy_tensor, aten_permute_copy_default),
)

aten_slice_copy_tensor_1 = builder.call_operator(
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(aten_squeeze_copy_dims, 0, 6, 12),
)
aten_permute_copy_default_1 = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(p_linear_weight, [1, 0]),
)
aten_mm_default_1 = builder.call_operator(
op=exir_ops.edge.aten.mm.default,
args=(aten_slice_copy_tensor_1, aten_permute_copy_default_1),
)

aten_slice_copy_tensor_2 = builder.call_operator(
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(aten_squeeze_copy_dims, 0, 12, 18),
)
aten_permute_copy_default_2 = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(p_linear_weight, [1, 0]),
)
aten_mm_default_2 = builder.call_operator(
op=exir_ops.edge.aten.mm.default,
args=(aten_slice_copy_tensor_2, aten_permute_copy_default_2),
)
builder.output([aten_mm_default, aten_mm_default_1, aten_mm_default_2])
original_graph = builder.get_graph_module()
graph_module = PostponeDequantizeOpBelowUseChainPass()(
original_graph
).graph_module
graph_module.graph.eliminate_dead_code()

# Assert that the dequant node was split into 4, one per branch
self.assertEqual(
count_node(
@@ -261,31 +435,35 @@ def forward(self, x):

# 4d -> permute -> 4d -> view -> 3d
def test_permute3_view4_chains(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 1, 768))
aten_view_copy_default = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(x, [3, 12, 64]),
)
aten_permute_copy_default = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(aten_view_copy_default, [1, 0, 2]),
)
aten_view_copy_default_1 = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(aten_permute_copy_default, [1, 12, 3, 64]),
)
aten_permute_copy_default_1 = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(aten_view_copy_default_1, [0, 1, 3, 2]),
)
builder.output(
aten_permute_copy_default_1,
)
original_graph = builder.get_graph_module()
# Performing transform
converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
original_graph
).graph_module
converted_graph.graph.eliminate_dead_code()
# Assert the order becomes view, view, permute, permute
nodes = get_compute_nodes_in_gm(converted_graph)
self.assertEqual(len(nodes), 4)
self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
@@ -294,31 +472,36 @@ def forward(self, x):

# 3d -> permute -> 3d -> view -> 4d
def test_permute4_view3_chains(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 1, 768))
aten_view_copy_default = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(x, [1, 3, 12, 64]),
)
aten_permute_copy_default = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(aten_view_copy_default, [3, 1, 0, 2]),
)
aten_view_copy_default_1 = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(aten_permute_copy_default, [64, 3, 12]),
)
aten_permute_copy_default_1 = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(aten_view_copy_default_1, [2, 1, 0]),
)
builder.output(
aten_permute_copy_default_1,
)
original_graph = builder.get_graph_module()
# Performing transform
converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
original_graph
).graph_module
converted_graph.graph.eliminate_dead_code()

# Assert the order becomes view, view, permute, permute
nodes = get_compute_nodes_in_gm(converted_graph)
self.assertEqual(len(nodes), 4)
self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
@@ -329,31 +512,36 @@ def forward(self, x):
# permute->4d->view->3d where the view not only removes the dimension whose
# size is 1 (this is ok), but also changes the size of the dimensions (not ok).
def test_permute_view_chains_neg(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 1, 768))
aten_view_copy_default = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(x, [1, 3, 12, 64]),
)
aten_permute_copy_default = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(aten_view_copy_default, [3, 1, 0, 2]),
)
aten_view_copy_default_1 = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default,
args=(aten_permute_copy_default, [64, 6, 6]),
)
aten_permute_copy_default_1 = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(aten_view_copy_default_1, [2, 1, 0]),
)
builder.output(
aten_permute_copy_default_1,
)
original_graph = builder.get_graph_module()
# Performing transform (nothing should happen)
converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
original_graph
).graph_module
converted_graph.graph.eliminate_dead_code()

# Assert the order is still view, permute, view, permute
nodes = get_compute_nodes_in_gm(converted_graph)
self.assertEqual(len(nodes), 4)
self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy)