Skip to content

Commit 206dcf4

Browse files
authored
Updated merging logic of AOT to get the corresponding intermediate output
Differential Revision: D77956435 Pull Request resolved: #12351
1 parent 31351b0 commit 206dcf4

File tree

2 files changed

+77
-75
lines changed

2 files changed

+77
-75
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -625,23 +625,28 @@ def _debug_handles_have_overlap(
625625
return len(aot_set.intersection(runtime_set)) > 0
626626

627627

628-
def _combine_debug_handles(debug_handles: List[DebugHandle]) -> DebugHandle:
629-
"""Combine multiple debug handles into one debug handle"""
630-
combined_debug_handles_set = set()
631-
for debug_handle in debug_handles:
632-
combined_debug_handles_set.update(set(debug_handle))
633-
return tuple(sorted(combined_debug_handles_set))
628+
def _combine_aot_overlapped_intermediate_outputs(
629+
aot_nodes: List[Tuple[DebugHandle, Any]], runtime_node: Tuple[DebugHandle, Any]
630+
) -> Tuple[DebugHandle, Any]:
631+
"""
632+
Ensure the AOT combined debug_handles are the same as the runtime debug_handles (order ignored),
633+
then pick the last intermediate output based on the runtime debug_handles
634+
"""
635+
# Map AOT single element debug_handles to outputs
636+
aot_map = dict(aot_nodes)
637+
runtime_debug_handle, _ = runtime_node
634638

639+
# Combine all AOT debug_handles into a list
640+
aot_combined_debug_handle = [t[0] for t in aot_map.keys()]
635641

636-
def _combine_overlapped_intermediate_outputs(
637-
nodes: List[Tuple[DebugHandle, Any]]
638-
) -> Tuple[DebugHandle, Any]:
639-
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
640-
debug_handles = [debug_handle for debug_handle, _ in nodes]
641-
outputs = [output for _, output in nodes]
642-
combined_debug_handle = _combine_debug_handles(debug_handles)
643-
output = outputs[-1] # Pick the last one
644-
return combined_debug_handle, output
642+
if set(aot_combined_debug_handle) != set(runtime_debug_handle):
643+
# AOT combined debug_handle and runtime debug_handle do not match.
644+
return (-1,), None
645+
646+
# Pick the last intermediate output
647+
last_int = runtime_debug_handle[-1]
648+
key = (last_int,)
649+
return runtime_debug_handle, aot_map[key]
645650

646651

647652
def _create_debug_handle_overlap_graph(
@@ -751,38 +756,52 @@ def map_runtime_aot_intermediate_outputs(
751756

752757
# Map only if both AOT and runtime data are present.
753758
if len(aot_list) != 0 and len(runtime_list) != 0:
759+
# The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
760+
# Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_handles.
761+
# As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
762+
assert len(runtime_list) == 1
763+
runtime_debug_handle, runtime_intermediate_output = runtime_list[0]
764+
754765
# Combine aot debug handles into a single key
755766
aot_combined_debug_handle, aot_intermediate_output = (
756-
_combine_overlapped_intermediate_outputs(aot_list)
767+
_combine_aot_overlapped_intermediate_outputs(aot_list, runtime_list[0])
757768
)
758-
# Combine runtime debug handles into a single key
759-
runtime_combined_debug_handle, runtime_intermediate_output = (
760-
_combine_overlapped_intermediate_outputs(runtime_list)
761-
)
762-
# List can't be used as a key, so convert to tuple
763-
if isinstance(aot_intermediate_output, list):
769+
770+
if aot_combined_debug_handle == (-1,):
771+
# Skip this mapping if the aot combined debug handle and runtime debug handle do not exactly match.
772+
continue
773+
774+
if isinstance(aot_intermediate_output, Sequence):
775+
if not isinstance(runtime_intermediate_output, Sequence):
776+
raise TypeError(
777+
"runtime intermediate output should be a sequence when aot intermediate output is a sequence"
778+
)
779+
last_element = runtime_intermediate_output[-1]
780+
if isinstance(last_element, list) and all(
781+
isinstance(t, torch.Tensor) for t in last_element
782+
):
783+
# If the last element is a list of tensors (delegate case)
784+
runtime_intermediate_output = last_element
785+
elif isinstance(last_element, torch.Tensor):
786+
# If the last element is a tensor (non-delegate case)
787+
pass
788+
else:
789+
raise ValueError(
790+
"The last element of runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence"
791+
)
792+
# List can't be used as a key, so convert to tuple
764793
aot_intermediate_output = tuple(aot_intermediate_output)
765-
# runtime follow the same format as aot, so it's safe to convert to tuple
766-
if isinstance(runtime_intermediate_output, list):
767794
runtime_intermediate_output = tuple(runtime_intermediate_output)
768795

769-
# Currently, runtime_intermediate_output logs all delegate call arguments.
770-
# Process here to extract only the outputs.
771-
if isinstance(aot_intermediate_output, tuple):
772-
# If both are sequences, slice runtime_intermediate_output to match the length of aot_intermediate_output
773-
if isinstance(runtime_intermediate_output, tuple):
774-
runtime_intermediate_output = runtime_intermediate_output[
775-
-len(aot_intermediate_output) :
776-
]
777-
# If aot_intermediate_output is not a sequence but runtime_intermediate_output is, get the last element
778-
elif isinstance(runtime_intermediate_output, tuple):
796+
elif isinstance(runtime_intermediate_output, Sequence):
797+
# delegate runtime call and AOT intermediate is not a sequence, just take the last element from runtime list
779798
runtime_intermediate_output = runtime_intermediate_output[-1]
780799

781800
# Create a mapping between runtime and aot
782801
aot_runtime_mapping[
783802
(aot_combined_debug_handle, aot_intermediate_output)
784803
] = (
785-
runtime_combined_debug_handle,
804+
runtime_debug_handle,
786805
runtime_intermediate_output,
787806
)
788807

@@ -890,7 +909,9 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
890909
if is_a_sequence and is_b_sequence:
891910
# Ensure both sequences have the same length
892911
if len(a) != len(b):
893-
raise ValueError("Sequences must have the same length for comparison.")
912+
raise ValueError(
913+
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
914+
)
894915

895916
# Compare each element in the sequences and return the list of results
896917
return [comparator.compare(x, y) for x, y in zip(a, b)]
@@ -899,7 +920,9 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
899920
return [comparator.compare(a, b)]
900921
else:
901922
# Raise an error if one is a sequence and the other is not
902-
raise ValueError("Both inputs must be sequences or both must be non-sequences.")
923+
raise ValueError(
924+
f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences."
925+
)
903926

904927

905928
def propagate_back_debug_handle(

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -302,67 +302,46 @@ def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
302302
}
303303
self.assertEqual(actual, expected)
304304

305-
def test_map_runtime_aot_intermediate_outputs_exact_match(self):
306-
# Exact match between aot and runtime debug_handles
307-
aot_intermediate_outputs = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
308-
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
309-
actual = map_runtime_aot_intermediate_outputs(
310-
aot_intermediate_outputs, runtime_intermediate_outputs
311-
)
312-
expected = {
313-
((0, 1), 100): ((0, 1), 150),
314-
((2, 3), 200): ((2, 3), 200),
315-
((4, 5), 300): ((4, 5), 300),
316-
}
317-
self.assertEqual(actual, expected)
318-
319305
def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
320306
# No overlaps between aot and runtime debug_handles
321-
aot_intermediate_outputs = {(0, 1): 100, (4, 5): 300}
307+
aot_intermediate_outputs = {(0,): 100, (4,): 300}
322308
runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
323309
actual = map_runtime_aot_intermediate_outputs(
324310
aot_intermediate_outputs, runtime_intermediate_outputs
325311
)
326312
expected = {}
327313
self.assertEqual(actual, expected)
328314

329-
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
330-
# Multiple aot debug_handles map to one runtime debug_handle
331-
aot_intermediate_outputs = {(0, 1, 2): 100, (3, 4): 300}
332-
runtime_intermediate_outputs = {(1, 2, 3): 250, (8, 9): 300}
333-
actual = map_runtime_aot_intermediate_outputs(
334-
aot_intermediate_outputs, runtime_intermediate_outputs
335-
)
336-
expected = {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)}
337-
self.assertEqual(actual, expected)
338-
339-
def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
340-
# One aot debug_handle map to multiple runtime debug_handles
341-
aot_intermediate_outputs = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
342-
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
315+
def test_map_runtime_aot_intermediate_outputs_partial_match(self):
316+
# Partial match between aot and runtime debug_handles will return empty
317+
aot_intermediate_outputs = {(2,): 100, (9,): 300}
318+
runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
343319
actual = map_runtime_aot_intermediate_outputs(
344320
aot_intermediate_outputs, runtime_intermediate_outputs
345321
)
346-
expected = {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)}
322+
expected = {}
347323
self.assertEqual(actual, expected)
348324

349-
def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
350-
# Complex chain (N-to-N mapping)
351-
aot_intermediate_outputs = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
352-
runtime_intermediate_outputs = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
325+
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
326+
# Multiple aot debug_handles map to one runtime debug_handle
327+
aot_intermediate_outputs = {(0,): 100, (1,): 200, (2,): 300, (3,): 400}
328+
runtime_intermediate_outputs = {(2, 3, 1): 250, (8, 9): 300}
353329
actual = map_runtime_aot_intermediate_outputs(
354330
aot_intermediate_outputs, runtime_intermediate_outputs
355331
)
356-
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
332+
expected = {((2, 3, 1), 200): ((2, 3, 1), 250)}
357333
self.assertEqual(actual, expected)
358334

359335
def test_map_runtime_aot_intermediate_outputs_delegated(self):
360336
# Currently, runtime_intermediate_output logs all delegate call arguments
361337
# Test that the map function correctly extracted out the delegated outputs
362338
aot_intermediate_outputs = {
363-
(1, 2): torch.tensor([4, 5]),
364-
(3, 4): torch.tensor([10, 11, 12]),
365-
(5, 6): torch.tensor([13, 14, 15, 16, 17]),
339+
(1,): torch.tensor([4, 1]),
340+
(2,): torch.tensor([4, 5]),
341+
(3,): torch.tensor([10, 10, 13]),
342+
(4,): torch.tensor([10, 11, 12]),
343+
(5,): torch.tensor([13, 14, 15, 16, 21]),
344+
(6,): torch.tensor([13, 14, 15, 16, 17]),
366345
}
367346
runtime_intermediate_outputs = {
368347
(1, 2): [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],

0 commit comments

Comments
 (0)