 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
-from executorch.backends.cadence.aot.compiler import (
-    export_to_edge,
-    quantize_and_export_to_cadence,
-)
+
 from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
+from executorch.backends.cadence.aot.graph_builder import GraphBuilder
 from executorch.backends.cadence.aot.pass_utils import (
     count_node,
     get_compute_nodes_in_gm,
+    get_node_pos,
     nodes_not_adjacent_in_gm,
     nodes_not_connected_in_gm,
 )
 from executorch.backends.cadence.aot.reorder_ops import (
+    AdvanceQuantizeOpAboveDefChainPass,
     AdvanceQuantizeOpAboveDefInBranchPass,
     PostponeDequantizeOpBelowUseChainPass,
     PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
+    SinkOpsCloserToUsePass,
 )
 from executorch.exir.dialects._ops import ops as exir_ops


 class TestReorderPasses(unittest.TestCase):
     def test_sink_dequantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
-
-            def forward(self, x, y):
-                x1 = self.linear(x)
-                y1 = self.linear(y)
-                x2 = torch.ops.aten.abs(x1)
-                return torch.ops.aten.cat((x2, y1))
-
-        inputs = (torch.randn(32, 6), torch.randn(32, 6))
-        graph_module = (
-            quantize_and_export_to_cadence(M(), inputs).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(32, 6, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(32, 6, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights", torch.randint(-128, 127, (6, 8), dtype=torch.int8)
+        )
+        x_quantized = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 0.02252197265625, 20, -128, 127, torch.int8),
+        )
+        y_quantized = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(y, 0.02181086875498295, -11, -128, 127, torch.int8),
+        )
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
+        )
+        full_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1253324672),
+        )
+        full_2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -3),
+        )
+        full_3 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0.0),
+        )
+        full_4 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
+        )
+        full_5 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1290687488),
+        )
+        full_6 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -3),
+        )
+        full_7 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0.0),
+        )
+        quantized_linear = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(x_quantized, weights, full_3, 20, full_2, full_1, full, 13, None),
         )
+        quantized_linear_1 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(y_quantized, weights, full_7, -11, full_6, full_5, full_4, 8, None),
+        )
+        dequantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear, 0.015294239856302738, 13, -128, 127, torch.int8),
+        )
+        dequantize_per_tensor_1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear_1, 0.014382584020495415, 8, -128, 127, torch.int8),
+        )
+        abs_1 = builder.call_operator(
+            op=exir_ops.edge.aten.abs.default,
+            args=(dequantize_per_tensor,),
+        )
+        cat = builder.call_operator(
+            op=exir_ops.edge.aten.cat.default,
+            args=([abs_1, dequantize_per_tensor_1],),
+        )
+        builder.output(cat)
+        original_graph = builder.get_graph_module()
+        converted_graph = SinkOpsCloserToUsePass()(original_graph).graph_module
+
         # Expect the SinkDequant pass to move dequant(y) from above the abs to just below it
         self.assertTrue(
             nodes_not_adjacent_in_gm(
-                graph_module,
+                converted_graph,
                 exir_ops.edge.aten.abs.default,
                 exir_ops.edge.aten.cat.default,
             ),
         )
         self.assertTrue(
             nodes_not_adjacent_in_gm(
-                graph_module,
+                converted_graph,
                 exir_ops.edge.cadence.dequantize_per_tensor.default,
                 exir_ops.edge.cadence.dequantize_per_tensor.default,
             ),
         )

     def test_advance_branched_quantize(self):
-        class ReorderOpsBranch(torch.nn.Module):
-            def forward(self, x):
-                x = x.view((32, 6))
-                x1 = torch.slice_copy(x, dim=0, start=0, end=6, step=1)
-                x1 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x1, 0.1, 10, 0, 255, torch.uint8
-                )
-                x2 = torch.slice_copy(x, dim=0, start=6, end=12, step=1)
-                x2 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x2, 0.1, 10, 0, 255, torch.uint8
-                )
-                x3 = torch.slice_copy(x, dim=0, start=12, end=18, step=1)
-                x3 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x3, 0.1, 10, 0, 255, torch.uint8
-                )
-                x4 = torch.slice_copy(x, dim=0, start=18, end=24, step=1)
-                x4 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
-                    x4, 0.2, 4, 0, 255, torch.uint8
-                )
-                return (x1, x2, x3, x4)
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(64, 3, dtype=torch.float32))
+        view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [32, 6]),
+        )
+        aten_slice_copy_tensor = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 0, 6),
+        )
+        quantized_decomposed_quantize_per_tensor_default = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor, 0.1, 10, 0, 255, torch.uint8),
+        )

-        model = ReorderOpsBranch()
-        X = torch.randn(64, 3)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
+        aten_slice_copy_tensor_1 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 6, 12),
+        )
+        quantized_decomposed_quantize_per_tensor_default_1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor_1, 0.1, 10, 0, 255, torch.uint8),
+        )
+
+        aten_slice_copy_tensor_2 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 12, 18),
+        )
+        quantized_decomposed_quantize_per_tensor_default_2 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor_2, 0.1, 10, 0, 255, torch.uint8),
+        )
+
+        aten_slice_copy_tensor_3 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(view, 0, 18, 24),
+        )
+        quantized_decomposed_quantize_per_tensor_default_3 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(aten_slice_copy_tensor_3, 0.2, 4, 0, 255, torch.uint8),
+        )
+        builder.output(
+            [
+                quantized_decomposed_quantize_per_tensor_default,
+                quantized_decomposed_quantize_per_tensor_default_1,
+                quantized_decomposed_quantize_per_tensor_default_2,
+                quantized_decomposed_quantize_per_tensor_default_3,
+            ]
+        )
+        original_graph = builder.get_graph_module()
         graph_module = AdvanceQuantizeOpAboveDefInBranchPass()(
-            graph_module
+            original_graph
         ).graph_module
         graph_module.graph.eliminate_dead_code()
         nodes = get_compute_nodes_in_gm(graph_module)
@@ -135,104 +221,191 @@ def forward(self, x):

     @torch.no_grad()
     def test_advance_quantize(self):
-        class ReorderOpsChain(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
-
-            def forward(self, x):
-                x = x.permute([1, 0, 3, 2])
-                x = self.linear(x)
-                return x
-
-        model = ReorderOpsChain()
-        X = torch.randn(16, 1, 6, 32)
-
-        graph_module = (
-            quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights", torch.randint(-128, 127, (32, 32), dtype=torch.int8)
         )
-        # Assert that the quant node is no longer the successor of
-        # permute node.
-        self.assertTrue(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.aten.permute_copy.default,
-                exir_ops.edge.cadence.quantize_per_tensor.default,
-            ),
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
         )
-        # Assert that permute node is the successor of quant node
-        self.assertFalse(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.cadence.quantize_per_tensor.default,
-                exir_ops.edge.aten.permute_copy.default,
+        full_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1525501056),
+        )
+        full_2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 2),
+        )
+        full_3 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([12], 0.0),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(x, [1, 0, 3, 2]),
+        )
+        quantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(permute, 0.029049983248114586, -1, -128, 127, torch.int8),
+        )
+        quantized_linear = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quantize_per_tensor,
+                weights,
+                full_3,
+                -1,
+                full_2,
+                full_1,
+                full,
+                -7,
+                None,
             ),
         )
-
-    def test_postpone_dequantize(self):
-        class ReorderOpsChain(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
-
-            def forward(self, x):
-                x = self.linear(x)
-                x = x.permute([1, 0, 3, 2])
-                return x
-
-        model = ReorderOpsChain()
-        X = torch.randn(1, 16, 32, 6)
-
-        graph_module = (
-            quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module
+        dequantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear, 0.01627226173877716, -7, -128, 127, torch.int8),
         )
-        # Assert that the dequant node is no longer the predecessor of the permute node
+        builder.output(dequantize_per_tensor)
+        original_graph = builder.get_graph_module()
+        converted_graph = AdvanceQuantizeOpAboveDefInBranchPass()(
+            original_graph
+        ).graph_module
+        converted_graph = AdvanceQuantizeOpAboveDefChainPass()(
+            original_graph
+        ).graph_module
+        # Assert that permute node is now the successor of the quant node.
         self.assertTrue(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.cadence.dequantize_per_tensor.default,
-                exir_ops.edge.aten.permute_copy.default,
-            ),
+            get_node_pos(
+                converted_graph, exir_ops.edge.cadence.quantize_per_tensor.default
+            )
+            < get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)
         )
-        # Assert that dequant node is the successor of permute node
-        self.assertFalse(
-            nodes_not_connected_in_gm(
-                graph_module,
-                exir_ops.edge.aten.permute_copy.default,
-                exir_ops.edge.cadence.dequantize_per_tensor.default,
+
+    def test_postpone_dequantize1(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights", torch.randint(-128, 127, (6, 6), dtype=torch.int8)
+        )
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -7),
+        )
+        full_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 1461148032),
+        )
+        full_2 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], -4),
+        )
+        full_3 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([12], 0.0),
+        )
+        quantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 0.029049983248114586, -1, -128, 127, torch.int8),
+        )
+        quantized_linear = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quantize_per_tensor,
+                weights,
+                full_3,
+                -8,
+                full_2,
+                full_1,
+                full,
+                0,
+                None,
             ),
         )
+        dequantize_per_tensor = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quantized_linear, 0.01627226173877716, -7, -128, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(dequantize_per_tensor, [1, 0, 3, 2]),
+        )
+        builder.output(permute)
+        original_graph = builder.get_graph_module()
+        converted_graph = PostponeDequantizeOpBelowUseChainPass()(
+            original_graph
+        ).graph_module
+        # Assert that dequant node is now the successor of the permute node.
+        self.assertTrue(
+            get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)
+            < get_node_pos(
+                converted_graph, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+            )
+        )

     def test_postpone_dequantize_branched(self):
-        class ReorderOpsBranch(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(3, 12, bias=False)
-
-            def forward(self, x):
-                x0 = exir_ops.edge.quantized_decomposed.dequantize_per_tensor(
-                    x, 0.1, 10, 0, 255, torch.uint8
-                )
-                x0 = torch.squeeze(x0, 0)
-                x1 = torch.slice_copy(x0, dim=0, start=0, end=6, step=1)
-                x1 = self.linear(x1)
-
-                x2 = torch.slice_copy(x0, dim=0, start=6, end=12, step=1)
-                x2 = self.linear(x2)
+        builder = GraphBuilder()
+        x = builder.placeholder(
+            "x", torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8)
+        )
+        p_linear_weight = builder.placeholder(
+            "weights", torch.randint(-128, 127, (3, 3), dtype=torch.int8)
+        )
+        quantized_decomposed_dequantize_per_tensor_default = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, 0.1, 10, 0, 255, torch.uint8),
+        )
+        aten_squeeze_copy_dims = builder.call_operator(
+            op=exir_ops.edge.aten.squeeze_copy.dims,
+            args=(quantized_decomposed_dequantize_per_tensor_default, [0]),
+        )

-                x3 = torch.slice_copy(x0, dim=0, start=12, end=18, step=1)
-                x3 = self.linear(x3)
+        aten_slice_copy_tensor = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(aten_squeeze_copy_dims, 0, 0, 6),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(p_linear_weight, [1, 0]),
+        )
+        aten_mm_default = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(aten_slice_copy_tensor, aten_permute_copy_default),
+        )

-                return (x1, x2, x3)
+        aten_slice_copy_tensor_1 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(aten_squeeze_copy_dims, 0, 6, 12),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(p_linear_weight, [1, 0]),
+        )
+        aten_mm_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(aten_slice_copy_tensor_1, aten_permute_copy_default_1),
+        )

-        model = ReorderOpsBranch()
-        X = torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
+        aten_slice_copy_tensor_2 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(aten_squeeze_copy_dims, 0, 12, 18),
+        )
+        aten_permute_copy_default_2 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(p_linear_weight, [1, 0]),
+        )
+        aten_mm_default_2 = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(aten_slice_copy_tensor_2, aten_permute_copy_default_2),
+        )
+        builder.output([aten_mm_default, aten_mm_default_1, aten_mm_default_2])
+        original_graph = builder.get_graph_module()
         graph_module = PostponeDequantizeOpBelowUseChainPass()(
-            graph_module
+            original_graph
         ).graph_module
         graph_module.graph.eliminate_dead_code()
-
         # Assert that the dequant node was split into 4, one per branch
         self.assertEqual(
             count_node(
@@ -261,31 +434,35 @@ def forward(self, x):

     # 4d -> permute -> 4d -> view -> 3d
     def test_permute3_view4_chains(self):
-        class PermuteViewChain(torch.nn.Module):
-            def forward(self, x):
-                # x is [3, 1, 768]
-                x = x.view((3, 12, 64))
-                # x is [3, 12, 64]
-                x = x.permute([1, 0, 2])
-                # x is [12, 3, 64]
-                x = x.view((1, 12, 3, 64))
-                # x is [1, 12, 3, 64]
-                x = x.permute([0, 1, 3, 2])
-                # x is [1, 12, 64, 3]
-                return x
-
-        model = PermuteViewChain()
-        X = torch.randn(3, 1, 768)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 768))
+        aten_view_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [3, 12, 64]),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default, [1, 0, 2]),
+        )
+        aten_view_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(aten_permute_copy_default, [1, 12, 3, 64]),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default_1, [0, 1, 3, 2]),
+        )
+        builder.output(
+            aten_permute_copy_default_1,
+        )
+        original_graph = builder.get_graph_module()
         # Performing transform
-        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
-            graph_module
+        converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
+            original_graph
         ).graph_module
-        graph_module.graph.eliminate_dead_code()
-
+        converted_graph.graph.eliminate_dead_code()
         # Assert the order becomes view, view, permute, permute
-        nodes = get_compute_nodes_in_gm(graph_module)
+        nodes = get_compute_nodes_in_gm(converted_graph)
         self.assertEqual(len(nodes), 4)
         self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
         self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
@@ -294,31 +471,36 @@ def forward(self, x):

     # 3d -> permute -> 3d -> view -> 4d
     def test_permute4_view3_chains(self):
-        class PermuteViewChain(torch.nn.Module):
-            def forward(self, x):
-                # x is [3, 1, 768]
-                x = x.view((1, 3, 12, 64))
-                # x is [1, 3, 12, 64]
-                x = x.permute([3, 1, 0, 2])
-                # x is [64, 3, 1, 12]
-                x = x.view((64, 3, 12))
-                # x is [64, 3, 12]
-                x = x.permute([2, 1, 0])
-                # x is [12, 3, 64]
-                return x
-
-        model = PermuteViewChain()
-        X = torch.randn(3, 1, 768)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 768))
+        aten_view_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [1, 3, 12, 64]),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default, [3, 1, 0, 2]),
+        )
+        aten_view_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(aten_permute_copy_default, [64, 3, 12]),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default_1, [2, 1, 0]),
+        )
+        builder.output(
+            aten_permute_copy_default_1,
+        )
+        original_graph = builder.get_graph_module()
         # Performing transform
-        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
-            graph_module
+        converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
+            original_graph
         ).graph_module
-        graph_module.graph.eliminate_dead_code()
+        converted_graph.graph.eliminate_dead_code()

         # Assert the order becomes view, view, permute, permute
-        nodes = get_compute_nodes_in_gm(graph_module)
+        nodes = get_compute_nodes_in_gm(converted_graph)
         self.assertEqual(len(nodes), 4)
         self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
         self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
@@ -329,31 +511,36 @@ def forward(self, x):
     # permute->4d->view->3d where the view not only removes the dimension whose
     # size is 1 (this is ok), but also changes the size of the dimensions (not ok).
     def test_permute_view_chains_neg(self):
-        class PermuteViewChain(torch.nn.Module):
-            def forward(self, x):
-                # x is [3, 1, 768]
-                x = x.view((1, 3, 12, 64))
-                # x is [1, 3, 12, 64]
-                x = x.permute([3, 1, 0, 2])
-                # x is [64, 3, 1, 12]
-                x = x.view((64, 6, 6))
-                # x is [64, 6, 6]
-                x = x.permute([2, 1, 0])
-                # x is [6, 6, 64]
-                return x
-
-        model = PermuteViewChain()
-        X = torch.randn(3, 1, 768)
-        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 768))
+        aten_view_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(x, [1, 3, 12, 64]),
+        )
+        aten_permute_copy_default = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default, [3, 1, 0, 2]),
+        )
+        aten_view_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(aten_permute_copy_default, [64, 6, 6]),
+        )
+        aten_permute_copy_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(aten_view_copy_default_1, [2, 1, 0]),
+        )
+        builder.output(
+            aten_permute_copy_default_1,
+        )
+        original_graph = builder.get_graph_module()
         # Performing transform (nothing should happen)
-        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
-            graph_module
+        converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
+            original_graph
         ).graph_module
-        graph_module.graph.eliminate_dead_code()
+        converted_graph.graph.eliminate_dead_code()

         # Assert the order is still view, permute, view, permute
-        nodes = get_compute_nodes_in_gm(graph_module)
+        nodes = get_compute_nodes_in_gm(converted_graph)
         self.assertEqual(len(nodes), 4)
         self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
         self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy)