@@ -297,67 +297,46 @@ def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
297
297
}
298
298
self .assertEqual (actual , expected )
299
299
300
- def test_map_runtime_aot_intermediate_outputs_exact_match (self ):
301
- # Exact match between aot and runtime debug_handles
302
- aot_intermediate_outputs = {(0 , 1 ): 100 , (2 , 3 ): 200 , (4 , 5 ): 300 }
303
- runtime_intermediate_outputs = {(0 , 1 ): 150 , (2 , 3 ): 200 , (4 , 5 ): 300 }
304
- actual = map_runtime_aot_intermediate_outputs (
305
- aot_intermediate_outputs , runtime_intermediate_outputs
306
- )
307
- expected = {
308
- ((0 , 1 ), 100 ): ((0 , 1 ), 150 ),
309
- ((2 , 3 ), 200 ): ((2 , 3 ), 200 ),
310
- ((4 , 5 ), 300 ): ((4 , 5 ), 300 ),
311
- }
312
- self .assertEqual (actual , expected )
313
-
314
300
def test_map_runtime_aot_intermediate_outputs_no_overlaps (self ):
315
301
# No overlaps between aot and runtime debug_handles
316
- aot_intermediate_outputs = {(0 , 1 ): 100 , (4 , 5 ): 300 }
302
+ aot_intermediate_outputs = {(0 ,): 100 , (4 ,): 300 }
317
303
runtime_intermediate_outputs = {(2 , 3 ): 200 , (8 , 9 ): 300 }
318
304
actual = map_runtime_aot_intermediate_outputs (
319
305
aot_intermediate_outputs , runtime_intermediate_outputs
320
306
)
321
307
expected = {}
322
308
self .assertEqual (actual , expected )
323
309
324
- def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime (self ):
325
- # Multiple aot debug_handles map to one runtime debug_handle
326
- aot_intermediate_outputs = {(0 , 1 , 2 ): 100 , (3 , 4 ): 300 }
327
- runtime_intermediate_outputs = {(1 , 2 , 3 ): 250 , (8 , 9 ): 300 }
328
- actual = map_runtime_aot_intermediate_outputs (
329
- aot_intermediate_outputs , runtime_intermediate_outputs
330
- )
331
- expected = {((0 , 1 , 2 , 3 , 4 ), 300 ): ((1 , 2 , 3 ), 250 )}
332
- self .assertEqual (actual , expected )
310
+ def test_map_runtime_aot_intermediate_outputs_partial_match (self ):
311
+ # Partial match between aot and runtime debug_handles will raise an error
312
+ aot_intermediate_outputs = {(2 ,): 100 , (4 ,): 300 }
313
+ runtime_intermediate_outputs = {(2 , 3 ): 200 , (8 , 9 ): 300 }
333
314
334
- def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime (self ):
335
- # One aot debug_handle map to multiple runtime debug_handles
336
- aot_intermediate_outputs = {(0 , 1 , 2 , 3 , 4 ): 100 , (8 , 9 ): 300 }
337
- runtime_intermediate_outputs = {(0 , 1 ): 150 , (2 , 3 ): 200 , (4 , 5 ): 300 }
338
- actual = map_runtime_aot_intermediate_outputs (
339
- aot_intermediate_outputs , runtime_intermediate_outputs
340
- )
341
- expected = {((0 , 1 , 2 , 3 , 4 ), 100 ): ((0 , 1 , 2 , 3 , 4 , 5 ), 300 )}
342
- self .assertEqual (actual , expected )
315
+ with self .assertRaises (ValueError ):
316
+ map_runtime_aot_intermediate_outputs (
317
+ aot_intermediate_outputs , runtime_intermediate_outputs
318
+ )
343
319
344
- def test_map_runtime_aot_intermediate_outputs_complex_chain (self ):
345
- # Complex chain (N-to-N mapping)
346
- aot_intermediate_outputs = {(1 , 2 ): 100 , (3 , 4 ): 200 , (5 , 6 ): 300 }
347
- runtime_intermediate_outputs = {(2 , 3 ): 150 , ( 4 , 5 ): 250 , (6 , 7 ): 350 }
320
+ def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime (self ):
321
+ # Multiple aot debug_handles map to one runtime debug_handle
322
+ aot_intermediate_outputs = {(0 , ): 100 , (1 , ): 200 , (2 , ): 300 , ( 3 ,): 400 }
323
+ runtime_intermediate_outputs = {(2 , 3 , 1 ): 250 , (8 , 9 ): 300 }
348
324
actual = map_runtime_aot_intermediate_outputs (
349
325
aot_intermediate_outputs , runtime_intermediate_outputs
350
326
)
351
- expected = {((1 , 2 , 3 , 4 , 5 , 6 ), 300 ): ((2 , 3 , 4 , 5 , 6 , 7 ), 350 )}
327
+ expected = {((2 , 3 , 1 ), 200 ): ((2 , 3 , 1 ), 250 )}
352
328
self .assertEqual (actual , expected )
353
329
354
330
def test_map_runtime_aot_intermediate_outputs_delegated (self ):
355
331
# Currently, runtime_intermediate_output logs all delegate call arguments
356
332
# Test that the map function correctly extracted out the delegated outputs
357
333
aot_intermediate_outputs = {
358
- (1 , 2 ): torch .tensor ([4 , 5 ]),
359
- (3 , 4 ): torch .tensor ([10 , 11 , 12 ]),
360
- (5 , 6 ): torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
334
+ (1 ,): torch .tensor ([4 , 1 ]),
335
+ (2 ,): torch .tensor ([4 , 5 ]),
336
+ (3 ,): torch .tensor ([10 , 10 , 13 ]),
337
+ (4 ,): torch .tensor ([10 , 11 , 12 ]),
338
+ (5 ,): torch .tensor ([13 , 14 , 15 , 16 , 21 ]),
339
+ (6 ,): torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
361
340
}
362
341
runtime_intermediate_outputs = {
363
342
(1 , 2 ): [torch .tensor ([1 , 2 , 3 ]), torch .tensor ([4 , 5 ])],
0 commit comments