
Commit 97a95d7

eigen-k authored and facebook-github-bot committed on May 23, 2025
Use GraphBuilder in reorder unit tests. (#11103)
Summary:
Pull Request resolved: #11103

Use GraphBuilder in reorder unit tests.

Reviewed By: zonglinpeng

Differential Revision: D75257222
1 parent 4014cc6 commit 97a95d7

File tree

backends/cadence/aot/TARGETS
backends/cadence/aot/pass_utils.py
backends/cadence/aot/tests/test_reorder_ops_passes.py

3 files changed: +390 −189 lines changed

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -388,6 +388,7 @@ python_unittest(
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:compiler",
         "//executorch/backends/cadence/aot:fuse_ops",
+        "//executorch/backends/cadence/aot:graph_builder",
         "//executorch/backends/cadence/aot:ops_registrations",
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:reorder_ops",

backends/cadence/aot/pass_utils.py

Lines changed: 13 additions & 0 deletions
@@ -144,6 +144,19 @@ def nodes_not_connected_in_gm(
     return True
 
 
+# Returns the position of the first entry of a node of a given kind in the graph.
+def get_node_pos(
+    graph_module: torch.fx.GraphModule,
+    target: torch.fx.Node,
+) -> int:
+    pos = 0
+    for node in graph_module.graph.nodes:
+        if node.target == target:
+            return pos
+        pos += 1
+    return -1
+
+
 # Returns true if there is no instance of a node with target succ_target
 # positioned immediately after a node with target pred_target in the graph
 def nodes_not_adjacent_in_gm(
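
For reference, a minimal sketch of how the new helper composes with GraphBuilder, using only APIs that appear in this commit (the tiny graph below is illustrative, not part of the change):

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import get_node_pos
from executorch.exir.dialects._ops import ops as exir_ops

# Build a small edge-dialect graph: x -> view_copy -> permute_copy.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 1, 768))
view = builder.call_operator(
    op=exir_ops.edge.aten.view_copy.default,
    args=(x, [3, 12, 64]),
)
permute = builder.call_operator(
    op=exir_ops.edge.aten.permute_copy.default,
    args=(view, [1, 0, 2]),
)
builder.output(permute)
gm = builder.get_graph_module()

# get_node_pos walks graph.nodes in order: view_copy precedes permute_copy,
# and a node kind that never appears in the graph reports -1.
assert get_node_pos(gm, exir_ops.edge.aten.view_copy.default) < get_node_pos(
    gm, exir_ops.edge.aten.permute_copy.default
)
assert get_node_pos(gm, exir_ops.edge.aten.cat.default) == -1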

backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 376 additions & 189 deletions
@@ -11,85 +11,171 @@
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
-from executorch.backends.cadence.aot.compiler import (
-    export_to_edge,
-    quantize_and_export_to_cadence,
-)
+
 from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
+from executorch.backends.cadence.aot.graph_builder import GraphBuilder
 from executorch.backends.cadence.aot.pass_utils import (
     count_node,
     get_compute_nodes_in_gm,
+    get_node_pos,
     nodes_not_adjacent_in_gm,
     nodes_not_connected_in_gm,
 )
 from executorch.backends.cadence.aot.reorder_ops import (
+    AdvanceQuantizeOpAboveDefChainPass,
     AdvanceQuantizeOpAboveDefInBranchPass,
     PostponeDequantizeOpBelowUseChainPass,
     PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
+    SinkOpsCloserToUsePass,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 
 
 class TestReorderPasses(unittest.TestCase):
     def test_sink_dequantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
-
-            def forward(self, x, y):
-                x1 = self.linear(x)
-                y1 = self.linear(y)
-                x2 = torch.ops.aten.abs(x1)
-                return torch.ops.aten.cat((x2, y1))
-
-        inputs = (torch.randn(32, 6), torch.randn(32, 6))
-        graph_module = (
-            quantize_and_export_to_cadence(M(), inputs).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(32, 6, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(32, 6, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights", torch.randint(-128, 127, (6, 8), dtype=torch.int8)
+        )
+        x_quantized = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 0.02252197265625, 20, -128, 127, torch.int8),
+        )
+        y_quantized = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(y, 0.02181086875498295, -11, -128, 127, torch.int8),
+        )
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
+        )
+        full_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1253324672),
+        )
+        full_2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -3),
+        )
+        full_3 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0.0),
+        )
+        full_4 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
+        )
+        full_5 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1290687488),
+        )
+        full_6 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -3),
+        )
+        full_7 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0.0),
+        )
+        quantized_linear = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(x_quantized, weights, full_3, 20, full_2, full_1, full, 13, None),
         )
+        quantized_linear_1 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(y_quantized, weights, full_7, -11, full_6, full_5, full_4, 8, None),
+        )
+        dequantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear, 0.015294239856302738, 13, -128, 127, torch.int8),
+        )
+        dequantize_per_tensor_1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear_1, 0.014382584020495415, 8, -128, 127, torch.int8),
+        )
+        abs_1 = builder.call_operator(
+            op=exir_ops.edge.aten.abs.default,
+            args=(dequantize_per_tensor,),
+        )
+        cat = builder.call_operator(
+            op=exir_ops.edge.aten.cat.default,
+            args=([abs_1, dequantize_per_tensor_1],),
+        )
+        builder.output(cat)
+        original_graph = builder.get_graph_module()
+        converted_graph = SinkOpsCloserToUsePass()(original_graph).graph_module
+
         # Expect the SinkDequant pass to move dequant(y) from above the relu to just below it
         self.assertTrue(
             nodes_not_adjacent_in_gm(
-                graph_module,
+                converted_graph,
                 exir_ops.edge.aten.abs.default,
                 exir_ops.edge.aten.cat.default,
             ),
         )
         self.assertTrue(
             nodes_not_adjacent_in_gm(
-                graph_module,
+                converted_graph,
                 exir_ops.edge.cadence.dequantize_per_tensor.default,
                 exir_ops.edge.cadence.dequantize_per_tensor.default,
             ),
         )
 
     def test_advance_branched_quantize(self):
-        class ReorderOpsBranch(torch.nn.Module):
-            def forward(self, x):
-                x = x.view((32, 6))
-                x1 = torch.slice_copy(x, dim=0, start=0, end=6, step=1)
-                x1 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x1, 0.1, 10, 0, 255, torch.uint8
-                )
-                x2 = torch.slice_copy(x, dim=0, start=6, end=12, step=1)
-                x2 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x2, 0.1, 10, 0, 255, torch.uint8
-                )
-                x3 = torch.slice_copy(x, dim=0, start=12, end=18, step=1)
-                x3 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x3, 0.1, 10, 0, 255, torch.uint8
-                )
-                x4 = torch.slice_copy(x, dim=0, start=18, end=24, step=1)
-                x4 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x4, 0.2, 4, 0, 255, torch.uint8
-                )
-                return (x1, x2, x3, x4)
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(64, 3, dtype=torch.float32))
+        view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [32, 6]),
+        )
+        aten_slice_copy_tensor = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 0, 6),
+        )
+        quantized_decomposed_quantize_per_tensor_default = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor, 0.1, 10, 0, 255, torch.uint8),
+        )
 
-        model = ReorderOpsBranch()
-        X = torch.randn(64, 3)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
+        aten_slice_copy_tensor_1 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 6, 12),
+        )
+        quantized_decomposed_quantize_per_tensor_default_1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor_1, 0.1, 10, 0, 255, torch.uint8),
+        )
+
+        aten_slice_copy_tensor_2 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 12, 18),
+        )
+        quantized_decomposed_quantize_per_tensor_default_2 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor_2, 0.1, 10, 0, 255, torch.uint8),
+        )
+
+        aten_slice_copy_tensor_3 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 18, 24),
+        )
+        quantized_decomposed_quantize_per_tensor_default_3 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor_3, 0.2, 4, 0, 255, torch.uint8),
+        )
+        builder.output(
+            [
+                quantized_decomposed_quantize_per_tensor_default,
+                quantized_decomposed_quantize_per_tensor_default_1,
+                quantized_decomposed_quantize_per_tensor_default_2,
+                quantized_decomposed_quantize_per_tensor_default_3,
+            ]
+        )
+        original_graph = builder.get_graph_module()
         graph_module = AdvanceQuantizeOpAboveDefInBranchPass()(
-            graph_module
+            original_graph
         ).graph_module
         graph_module.graph.eliminate_dead_code()
         nodes = get_compute_nodes_in_gm(graph_module)
@@ -135,104 +221,191 @@ def forward(self, x):
 
     @torch.no_grad()
     def test_advance_quantize(self):
-        class ReorderOpsChain(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
-
-            def forward(self, x):
-                x = x.permute([1, 0, 3, 2])
-                x = self.linear(x)
-                return x
-
-        model = ReorderOpsChain()
-        X = torch.randn(16, 1, 6, 32)
-
-        graph_module = (
-            quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights", torch.randint(-128, 127, (32, 32), dtype=torch.int8)
         )
-        # Assert that the quant node is no longer the successor of
-        # permute node.
-        self.assertTrue(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.aten.permute_copy.default,
-                exir_ops.edge.cadence.quantize_per_tensor.default,
-            ),
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
         )
-        # Assert that permute node is the successor of quant node
-        self.assertFalse(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.cadence.quantize_per_tensor.default,
-                exir_ops.edge.aten.permute_copy.default,
+        full_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1525501056),
+        )
+        full_2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 2),
+        )
+        full_3 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([12], 0.0),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(x, [1, 0, 3, 2]),
+        )
+        quantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(permute, 0.029049983248114586, -1, -128, 127, torch.int8),
+        )
+        quantized_linear = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quantize_per_tensor,
+                weights,
+                full_3,
+                -1,
+                full_2,
+                full_1,
+                full,
+                -7,
+                None,
             ),
         )
-
-    def test_postpone_dequantize(self):
-        class ReorderOpsChain(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
-
-            def forward(self, x):
-                x = self.linear(x)
-                x = x.permute([1, 0, 3, 2])
-                return x
-
-        model = ReorderOpsChain()
-        X = torch.randn(1, 16, 32, 6)
-
-        graph_module = (
-            quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module
+        dequantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear, 0.01627226173877716, -7, -128, 127, torch.int8),
         )
-        # Assert that the dequant node is no longer the predecessor of the permute node
+        builder.output(dequantize_per_tensor)
+        original_graph = builder.get_graph_module()
+        converted_graph = AdvanceQuantizeOpAboveDefInBranchPass()(
+            original_graph
+        ).graph_module
+        converted_graph = AdvanceQuantizeOpAboveDefChainPass()(
+            original_graph
+        ).graph_module
+        # Assert that permute node is now the successor of the quant node.
         self.assertTrue(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.cadence.dequantize_per_tensor.default,
-                exir_ops.edge.aten.permute_copy.default,
-            ),
+            get_node_pos(
+                converted_graph, exir_ops.edge.cadence.quantize_per_tensor.default
+            )
+            < get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)
         )
-        # Assert that dequant node is the successor of permute node
-        self.assertFalse(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.aten.permute_copy.default,
-                exir_ops.edge.cadence.dequantize_per_tensor.default,
+
+    def test_postpone_dequantize1(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights", torch.randint(-128, 127, (6, 6), dtype=torch.int8)
+        )
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
+        )
+        full_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1461148032),
+        )
+        full_2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -4),
+        )
+        full_3 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([12], 0.0),
+        )
+        quantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 0.029049983248114586, -1, -128, 127, torch.int8),
+        )
+        quantized_linear = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quantize_per_tensor,
+                weights,
+                full_3,
+                -8,
+                full_2,
+                full_1,
+                full,
+                0,
+                None,
             ),
         )
+        dequantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear, 0.01627226173877716, -7, -128, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(dequantize_per_tensor, [1, 0, 3, 2]),
+        )
+        builder.output(permute)
+        original_graph = builder.get_graph_module()
+        converted_graph = PostponeDequantizeOpBelowUseChainPass()(
+            original_graph
+        ).graph_module
+        # Assert that dequant node is now the successor of the permute node.
+        self.assertTrue(
+            get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)
+            < get_node_pos(
+                converted_graph, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+            )
+        )
 
     def test_postpone_dequantize_branched(self):
-        class ReorderOpsBranch(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(3, 12, bias=False)
-
-            def forward(self, x):
-                x0 = exir_ops.edge.quantized_decomposed.dequantize_per_tensor(
-                    x, 0.1, 10, 0, 255, torch.uint8
-                )
-                x0 = torch.squeeze(x0, 0)
-                x1 = torch.slice_copy(x0, dim=0, start=0, end=6, step=1)
-                x1 = self.linear(x1)
-
-                x2 = torch.slice_copy(x0, dim=0, start=6, end=12, step=1)
-                x2 = self.linear(x2)
+        builder = GraphBuilder()
+        x = builder.placeholder(
+            "x", torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8)
+        )
+        p_linear_weight = builder.placeholder(
+            "weights", torch.randint(-128, 127, (3, 3), dtype=torch.int8)
+        )
+        quantized_decomposed_dequantize_per_tensor_default = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, 0.1, 10, 0, 255, torch.uint8),
+        )
+        aten_squeeze_copy_dims = builder.call_operator(
+            op=exir_ops.edge.aten.squeeze_copy.dims,
+            args=(quantized_decomposed_dequantize_per_tensor_default, [0]),
+        )
 
-                x3 = torch.slice_copy(x0, dim=0, start=12, end=18, step=1)
-                x3 = self.linear(x3)
+        aten_slice_copy_tensor = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(aten_squeeze_copy_dims, 0, 0, 6),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(p_linear_weight, [1, 0]),
+        )
+        aten_mm_default = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(aten_slice_copy_tensor, aten_permute_copy_default),
+        )
 
-                return (x1, x2, x3)
+        aten_slice_copy_tensor_1 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(aten_squeeze_copy_dims, 0, 6, 12),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(p_linear_weight, [1, 0]),
+        )
+        aten_mm_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(aten_slice_copy_tensor_1, aten_permute_copy_default_1),
+        )
 
-        model = ReorderOpsBranch()
-        X = torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
+        aten_slice_copy_tensor_2 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(aten_squeeze_copy_dims, 0, 12, 18),
+        )
+        aten_permute_copy_default_2 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(p_linear_weight, [1, 0]),
+        )
+        aten_mm_default_2 = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(aten_slice_copy_tensor_2, aten_permute_copy_default_2),
+        )
+        builder.output([aten_mm_default, aten_mm_default_1, aten_mm_default_2])
+        original_graph = builder.get_graph_module()
         graph_module = PostponeDequantizeOpBelowUseChainPass()(
-            graph_module
+            original_graph
         ).graph_module
         graph_module.graph.eliminate_dead_code()
-
         # Assert that the dequant node was split into 4, one per branch
         self.assertEqual(
             count_node(
@@ -261,31 +434,35 @@ def forward(self, x):
 
     # 4d -> permute -> 4d -> view -> 3d
     def test_permute3_view4_chains(self):
-        class PermuteViewChain(torch.nn.Module):
-            def forward(self, x):
-                # x is [3, 1, 768]
-                x = x.view((3, 12, 64))
-                # x is [3, 12, 64]
-                x = x.permute([1, 0, 2])
-                # x is [12, 3, 64]
-                x = x.view((1, 12, 3, 64))
-                # x is [1, 12, 3, 64]
-                x = x.permute([0, 1, 3, 2])
-                # x is [1, 12, 64, 3]
-                return x
-
-        model = PermuteViewChain()
-        X = torch.randn(3, 1, 768)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 768))
+        aten_view_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [3, 12, 64]),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default, [1, 0, 2]),
+        )
+        aten_view_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(aten_permute_copy_default, [1, 12, 3, 64]),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default_1, [0, 1, 3, 2]),
+        )
+        builder.output(
+            aten_permute_copy_default_1,
+        )
+        original_graph = builder.get_graph_module()
         # Performing transform
-        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
-            graph_module
+        converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
+            original_graph
         ).graph_module
-        graph_module.graph.eliminate_dead_code()
-
+        converted_graph.graph.eliminate_dead_code()
         # Assert the order becomes view, view, permute, permute
-        nodes = get_compute_nodes_in_gm(graph_module)
+        nodes = get_compute_nodes_in_gm(converted_graph)
         self.assertEqual(len(nodes), 4)
         self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
         self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
@@ -294,31 +471,36 @@ def forward(self, x):
 
     # 3d -> permute -> 3d -> view -> 4d
     def test_permute4_view3_chains(self):
-        class PermuteViewChain(torch.nn.Module):
-            def forward(self, x):
-                # x is [3, 1, 768]
-                x = x.view((1, 3, 12, 64))
-                # x is [1, 3, 12, 64]
-                x = x.permute([3, 1, 0, 2])
-                # x is [64, 3, 1, 12]
-                x = x.view((64, 3, 12))
-                # x is [64, 3, 12]
-                x = x.permute([2, 1, 0])
-                # x is [12, 3, 64]
-                return x
-
-        model = PermuteViewChain()
-        X = torch.randn(3, 1, 768)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 768))
+        aten_view_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [1, 3, 12, 64]),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default, [3, 1, 0, 2]),
+        )
+        aten_view_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(aten_permute_copy_default, [64, 3, 12]),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default_1, [2, 1, 0]),
+        )
+        builder.output(
+            aten_permute_copy_default_1,
+        )
+        original_graph = builder.get_graph_module()
         # Performing transform
-        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
-            graph_module
+        converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
+            original_graph
         ).graph_module
-        graph_module.graph.eliminate_dead_code()
+        converted_graph.graph.eliminate_dead_code()
 
         # Assert the order becomes view, view, permute, permute
-        nodes = get_compute_nodes_in_gm(graph_module)
+        nodes = get_compute_nodes_in_gm(converted_graph)
         self.assertEqual(len(nodes), 4)
         self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
         self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
@@ -329,31 +511,36 @@ def forward(self, x):
     # permute->4d->view->3d where the view not only removes the dimension whose
     # size is 1 (this is ok), but also changes the size of the dimensions (not ok).
     def test_permute_view_chains_neg(self):
-        class PermuteViewChain(torch.nn.Module):
-            def forward(self, x):
-                # x is [3, 1, 768]
-                x = x.view((1, 3, 12, 64))
-                # x is [1, 3, 12, 64]
-                x = x.permute([3, 1, 0, 2])
-                # x is [64, 3, 1, 12]
-                x = x.view((64, 6, 6))
-                # x is [64, 6, 6]
-                x = x.permute([2, 1, 0])
-                # x is [6, 6, 64]
-                return x
-
-        model = PermuteViewChain()
-        X = torch.randn(3, 1, 768)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 768))
+        aten_view_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [1, 3, 12, 64]),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default, [3, 1, 0, 2]),
+        )
+        aten_view_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(aten_permute_copy_default, [64, 6, 6]),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default_1, [2, 1, 0]),
+        )
+        builder.output(
+            aten_permute_copy_default_1,
+        )
+        original_graph = builder.get_graph_module()
         # Performing transform (nothing should happen)
-        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
-            graph_module
+        converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
+            original_graph
         ).graph_module
-        graph_module.graph.eliminate_dead_code()
+        converted_graph.graph.eliminate_dead_code()
 
         # Assert the order is still view, permute, view, permute
-        nodes = get_compute_nodes_in_gm(graph_module)
+        nodes = get_compute_nodes_in_gm(converted_graph)
         self.assertEqual(len(nodes), 4)
         self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
         self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy)
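
All of the rewritten tests follow the same shape: build an edge-dialect graph directly with GraphBuilder instead of exporting an nn.Module, run a single reorder pass over it, and assert on the resulting node order. A condensed sketch of that pattern, with an illustrative graph and pass choice:

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.reorder_ops import (
    PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
)
from executorch.exir.dialects._ops import ops as exir_ops

# 1. Construct the input graph node by node.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 1, 768))
view = builder.call_operator(
    op=exir_ops.edge.aten.view_copy.default,
    args=(x, [3, 12, 64]),
)
permute = builder.call_operator(
    op=exir_ops.edge.aten.permute_copy.default,
    args=(view, [1, 0, 2]),
)
builder.output(permute)
original_graph = builder.get_graph_module()

# 2. Each pass is callable; the transformed module hangs off .graph_module.
converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
    original_graph
).graph_module
converted_graph.graph.eliminate_dead_code()

# 3. Assert on node order or counts with the pass_utils helpers
# (get_node_pos, count_node, nodes_not_adjacent_in_gm, ...).

A side effect of building the graphs directly is that the tests no longer depend on export_to_edge or quantize_and_export_to_cadence producing a particular graph, which is why those compiler imports could be dropped.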
