@@ -302,67 +302,46 @@ def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
302
302
}
303
303
self .assertEqual (actual , expected )
304
304
305
- def test_map_runtime_aot_intermediate_outputs_exact_match (self ):
306
- # Exact match between aot and runtime debug_handles
307
- aot_intermediate_outputs = {(0 , 1 ): 100 , (2 , 3 ): 200 , (4 , 5 ): 300 }
308
- runtime_intermediate_outputs = {(0 , 1 ): 150 , (2 , 3 ): 200 , (4 , 5 ): 300 }
309
- actual = map_runtime_aot_intermediate_outputs (
310
- aot_intermediate_outputs , runtime_intermediate_outputs
311
- )
312
- expected = {
313
- ((0 , 1 ), 100 ): ((0 , 1 ), 150 ),
314
- ((2 , 3 ), 200 ): ((2 , 3 ), 200 ),
315
- ((4 , 5 ), 300 ): ((4 , 5 ), 300 ),
316
- }
317
- self .assertEqual (actual , expected )
318
-
319
305
def test_map_runtime_aot_intermediate_outputs_no_overlaps (self ):
320
306
# No overlaps between aot and runtime debug_handles
321
- aot_intermediate_outputs = {(0 , 1 ): 100 , (4 , 5 ): 300 }
307
+ aot_intermediate_outputs = {(0 ,): 100 , (4 ,): 300 }
322
308
runtime_intermediate_outputs = {(2 , 3 ): 200 , (8 , 9 ): 300 }
323
309
actual = map_runtime_aot_intermediate_outputs (
324
310
aot_intermediate_outputs , runtime_intermediate_outputs
325
311
)
326
312
expected = {}
327
313
self .assertEqual (actual , expected )
328
314
329
- def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime (self ):
330
- # Multiple aot debug_handles map to one runtime debug_handle
331
- aot_intermediate_outputs = {(0 , 1 , 2 ): 100 , (3 , 4 ): 300 }
332
- runtime_intermediate_outputs = {(1 , 2 , 3 ): 250 , (8 , 9 ): 300 }
333
- actual = map_runtime_aot_intermediate_outputs (
334
- aot_intermediate_outputs , runtime_intermediate_outputs
335
- )
336
- expected = {((0 , 1 , 2 , 3 , 4 ), 300 ): ((1 , 2 , 3 ), 250 )}
337
- self .assertEqual (actual , expected )
338
-
339
- def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime (self ):
340
- # One aot debug_handle map to multiple runtime debug_handles
341
- aot_intermediate_outputs = {(0 , 1 , 2 , 3 , 4 ): 100 , (8 , 9 ): 300 }
342
- runtime_intermediate_outputs = {(0 , 1 ): 150 , (2 , 3 ): 200 , (4 , 5 ): 300 }
315
+ def test_map_runtime_aot_intermediate_outputs_partial_match (self ):
316
+ # Partial match between aot and runtime debug_handles will return empty
317
+ aot_intermediate_outputs = {(2 ,): 100 , (9 ,): 300 }
318
+ runtime_intermediate_outputs = {(2 , 3 ): 200 , (8 , 9 ): 300 }
343
319
actual = map_runtime_aot_intermediate_outputs (
344
320
aot_intermediate_outputs , runtime_intermediate_outputs
345
321
)
346
- expected = {(( 0 , 1 , 2 , 3 , 4 ), 100 ): (( 0 , 1 , 2 , 3 , 4 , 5 ), 300 ) }
322
+ expected = {}
347
323
self .assertEqual (actual , expected )
348
324
349
- def test_map_runtime_aot_intermediate_outputs_complex_chain (self ):
350
- # Complex chain (N-to-N mapping)
351
- aot_intermediate_outputs = {(1 , 2 ): 100 , (3 , 4 ): 200 , (5 , 6 ): 300 }
352
- runtime_intermediate_outputs = {(2 , 3 ): 150 , ( 4 , 5 ): 250 , (6 , 7 ): 350 }
325
+ def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime (self ):
326
+ # Multiple aot debug_handles map to one runtime debug_handle
327
+ aot_intermediate_outputs = {(0 , ): 100 , (1 , ): 200 , (2 , ): 300 , ( 3 ,): 400 }
328
+ runtime_intermediate_outputs = {(2 , 3 , 1 ): 250 , (8 , 9 ): 300 }
353
329
actual = map_runtime_aot_intermediate_outputs (
354
330
aot_intermediate_outputs , runtime_intermediate_outputs
355
331
)
356
- expected = {((1 , 2 , 3 , 4 , 5 , 6 ), 300 ): ((2 , 3 , 4 , 5 , 6 , 7 ), 350 )}
332
+ expected = {((2 , 3 , 1 ), 200 ): ((2 , 3 , 1 ), 250 )}
357
333
self .assertEqual (actual , expected )
358
334
359
335
def test_map_runtime_aot_intermediate_outputs_delegated (self ):
360
336
# Currently, runtime_intermediate_output logs all delegate call arguments
361
337
# Test that the map function correctly extracted out the delegated outputs
362
338
aot_intermediate_outputs = {
363
- (1 , 2 ): torch .tensor ([4 , 5 ]),
364
- (3 , 4 ): torch .tensor ([10 , 11 , 12 ]),
365
- (5 , 6 ): torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
339
+ (1 ,): torch .tensor ([4 , 1 ]),
340
+ (2 ,): torch .tensor ([4 , 5 ]),
341
+ (3 ,): torch .tensor ([10 , 10 , 13 ]),
342
+ (4 ,): torch .tensor ([10 , 11 , 12 ]),
343
+ (5 ,): torch .tensor ([13 , 14 , 15 , 16 , 21 ]),
344
+ (6 ,): torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
366
345
}
367
346
runtime_intermediate_outputs = {
368
347
(1 , 2 ): [torch .tensor ([1 , 2 , 3 ]), torch .tensor ([4 , 5 ])],
0 commit comments