Skip to content

Commit acad8c1

Browse files
Juntian777facebook-github-bot
authored and committed
Updated merging logic of AOT to get the corresponding intermediate output (#12351)
Summary: Pull Request resolved: #12351 This PR updated the mapping between AOT and runtime. The runtime_list has a size of 1 because AOT is either an integer or a list of integers with a size of 1, so no merging is needed for runtime_list during the mapping stage. However, AOT may need to be merged during the mapping stage, and the intermediate output should be retained according to the last integer of the corresponding runtime debug_handle tuples. This ensures that the AOT and runtime intermediate outputs match. Additionally, ensure that the merged AOT debug_handle exactly matches the runtime debug_handle. Reviewed By: Gasoonjia Differential Revision: D77956435
1 parent bbe90bd commit acad8c1

File tree

2 files changed

+77
-75
lines changed

2 files changed

+77
-75
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -625,23 +625,28 @@ def _debug_handles_have_overlap(
625625
return len(aot_set.intersection(runtime_set)) > 0
626626

627627

628-
def _combine_debug_handles(debug_handles: List[DebugHandle]) -> DebugHandle:
629-
"""Combine multiple debug handles into one debug handle"""
630-
combined_debug_handles_set = set()
631-
for debug_handle in debug_handles:
632-
combined_debug_handles_set.update(set(debug_handle))
633-
return tuple(sorted(combined_debug_handles_set))
628+
def _combine_aot_overlapped_intermediate_outputs(
629+
aot_nodes: List[Tuple[DebugHandle, Any]], runtime_node: Tuple[DebugHandle, Any]
630+
) -> Tuple[DebugHandle, Any]:
631+
"""
632+
Ensure the AOT combined debug_handles are the same as the runtime debug handles. (order ignored)
633+
then pick the last intermediate output based on the runtime debug handles
634+
"""
635+
# Map AOT single element debug_handles to outputs
636+
aot_map = dict(aot_nodes)
637+
runtime_debug_handle, _ = runtime_node
634638

639+
# Combine all AOT debug_handles into a list
640+
aot_combined_debug_handle = [t[0] for t in aot_map.keys()]
635641

636-
def _combine_overlapped_intermediate_outputs(
637-
nodes: List[Tuple[DebugHandle, Any]]
638-
) -> Tuple[DebugHandle, Any]:
639-
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
640-
debug_handles = [debug_handle for debug_handle, _ in nodes]
641-
outputs = [output for _, output in nodes]
642-
combined_debug_handle = _combine_debug_handles(debug_handles)
643-
output = outputs[-1] # Pick the last one
644-
return combined_debug_handle, output
642+
if set(aot_combined_debug_handle) != set(runtime_debug_handle):
643+
# AOT combined debug_handle and runtime debug_handle do not match.
644+
return (-1,), None
645+
646+
# Pick the last intermediate output
647+
last_int = runtime_debug_handle[-1]
648+
key = (last_int,)
649+
return runtime_debug_handle, aot_map[key]
645650

646651

647652
def _create_debug_handle_overlap_graph(
@@ -751,38 +756,52 @@ def map_runtime_aot_intermediate_outputs(
751756

752757
# Map only if both AOT and runtime data are present.
753758
if len(aot_list) != 0 and len(runtime_list) != 0:
759+
# The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
760+
# Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_handles.
761+
# As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
762+
assert len(runtime_list) == 1
763+
runtime_debug_handle, runtime_intermediate_output = runtime_list[0]
764+
754765
# Combine aot debug handles into a single key
755766
aot_combined_debug_handle, aot_intermediate_output = (
756-
_combine_overlapped_intermediate_outputs(aot_list)
767+
_combine_aot_overlapped_intermediate_outputs(aot_list, runtime_list[0])
757768
)
758-
# Combine runtime debug handles into a single key
759-
runtime_combined_debug_handle, runtime_intermediate_output = (
760-
_combine_overlapped_intermediate_outputs(runtime_list)
761-
)
762-
# List can't be used as a key, so convert to tuple
763-
if isinstance(aot_intermediate_output, list):
769+
770+
if aot_combined_debug_handle == (-1,):
771+
# Skip this mapping if the aot combined debug handle and runtime debug handle do not exactly match.
772+
continue
773+
774+
if isinstance(aot_intermediate_output, Sequence):
775+
if not isinstance(runtime_intermediate_output, Sequence):
776+
raise TypeError(
777+
"runtime intermediate output should be a sequence when aot intermediate output is a sequence"
778+
)
779+
last_element = runtime_intermediate_output[-1]
780+
if isinstance(last_element, list) and all(
781+
isinstance(t, torch.Tensor) for t in last_element
782+
):
783+
# If the last element is a list of tensors (delegate case)
784+
runtime_intermediate_output = last_element
785+
elif isinstance(last_element, torch.Tensor):
786+
# If the last element is a tensor (non-delegate case)
787+
pass
788+
else:
789+
raise ValueError(
790+
"The last element of runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence"
791+
)
792+
# List can't be used as a key, so convert to tuple
764793
aot_intermediate_output = tuple(aot_intermediate_output)
765-
# runtime follow the same format as aot, so it's safe to convert to tuple
766-
if isinstance(runtime_intermediate_output, list):
767794
runtime_intermediate_output = tuple(runtime_intermediate_output)
768795

769-
# Currently, runtime_intermediate_output logs all delegate call arguments.
770-
# Process here to extract only the outputs.
771-
if isinstance(aot_intermediate_output, tuple):
772-
# If both are sequences, slice runtime_intermediate_output to match the length of aot_intermediate_output
773-
if isinstance(runtime_intermediate_output, tuple):
774-
runtime_intermediate_output = runtime_intermediate_output[
775-
-len(aot_intermediate_output) :
776-
]
777-
# If aot_intermediate_output is not a sequence but runtime_intermediate_output is, get the last element
778-
elif isinstance(runtime_intermediate_output, tuple):
796+
elif isinstance(runtime_intermediate_output, Sequence):
797+
# delegate runtime call and AOT intermediate is not a sequence, just take the last element from runtime list
779798
runtime_intermediate_output = runtime_intermediate_output[-1]
780799

781800
# Create a mapping between runtime and aot
782801
aot_runtime_mapping[
783802
(aot_combined_debug_handle, aot_intermediate_output)
784803
] = (
785-
runtime_combined_debug_handle,
804+
runtime_debug_handle,
786805
runtime_intermediate_output,
787806
)
788807

@@ -890,7 +909,9 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
890909
if is_a_sequence and is_b_sequence:
891910
# Ensure both sequences have the same length
892911
if len(a) != len(b):
893-
raise ValueError("Sequences must have the same length for comparison.")
912+
raise ValueError(
913+
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
914+
)
894915

895916
# Compare each element in the sequences and return the list of results
896917
return [comparator.compare(x, y) for x, y in zip(a, b)]
@@ -899,7 +920,9 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
899920
return [comparator.compare(a, b)]
900921
else:
901922
# Raise an error if one is a sequence and the other is not
902-
raise ValueError("Both inputs must be sequences or both must be non-sequences.")
923+
raise ValueError(
924+
f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences."
925+
)
903926

904927

905928
def propagate_back_debug_handle(

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -302,67 +302,46 @@ def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
302302
}
303303
self.assertEqual(actual, expected)
304304

305-
def test_map_runtime_aot_intermediate_outputs_exact_match(self):
306-
# Exact match between aot and runtime debug_handles
307-
aot_intermediate_outputs = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
308-
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
309-
actual = map_runtime_aot_intermediate_outputs(
310-
aot_intermediate_outputs, runtime_intermediate_outputs
311-
)
312-
expected = {
313-
((0, 1), 100): ((0, 1), 150),
314-
((2, 3), 200): ((2, 3), 200),
315-
((4, 5), 300): ((4, 5), 300),
316-
}
317-
self.assertEqual(actual, expected)
318-
319305
def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
320306
# No overlaps between aot and runtime debug_handles
321-
aot_intermediate_outputs = {(0, 1): 100, (4, 5): 300}
307+
aot_intermediate_outputs = {(0,): 100, (4,): 300}
322308
runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
323309
actual = map_runtime_aot_intermediate_outputs(
324310
aot_intermediate_outputs, runtime_intermediate_outputs
325311
)
326312
expected = {}
327313
self.assertEqual(actual, expected)
328314

329-
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
330-
# Multiple aot debug_handles map to one runtime debug_handle
331-
aot_intermediate_outputs = {(0, 1, 2): 100, (3, 4): 300}
332-
runtime_intermediate_outputs = {(1, 2, 3): 250, (8, 9): 300}
333-
actual = map_runtime_aot_intermediate_outputs(
334-
aot_intermediate_outputs, runtime_intermediate_outputs
335-
)
336-
expected = {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)}
337-
self.assertEqual(actual, expected)
338-
339-
def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
340-
# One aot debug_handle map to multiple runtime debug_handles
341-
aot_intermediate_outputs = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
342-
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
315+
def test_map_runtime_aot_intermediate_outputs_partial_match(self):
316+
# Partial match between aot and runtime debug_handles will return empty
317+
aot_intermediate_outputs = {(2,): 100, (9,): 300}
318+
runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
343319
actual = map_runtime_aot_intermediate_outputs(
344320
aot_intermediate_outputs, runtime_intermediate_outputs
345321
)
346-
expected = {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)}
322+
expected = {}
347323
self.assertEqual(actual, expected)
348324

349-
def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
350-
# Complex chain (N-to-N mapping)
351-
aot_intermediate_outputs = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
352-
runtime_intermediate_outputs = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
325+
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
326+
# Multiple aot debug_handles map to one runtime debug_handle
327+
aot_intermediate_outputs = {(0,): 100, (1,): 200, (2,): 300, (3,): 400}
328+
runtime_intermediate_outputs = {(2, 3, 1): 250, (8, 9): 300}
353329
actual = map_runtime_aot_intermediate_outputs(
354330
aot_intermediate_outputs, runtime_intermediate_outputs
355331
)
356-
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
332+
expected = {((2, 3, 1), 200): ((2, 3, 1), 250)}
357333
self.assertEqual(actual, expected)
358334

359335
def test_map_runtime_aot_intermediate_outputs_delegated(self):
360336
# Currently, runtime_intermediate_output logs all delegate call arguments
361337
# Test that the map function correctly extracted out the delegated outputs
362338
aot_intermediate_outputs = {
363-
(1, 2): torch.tensor([4, 5]),
364-
(3, 4): torch.tensor([10, 11, 12]),
365-
(5, 6): torch.tensor([13, 14, 15, 16, 17]),
339+
(1,): torch.tensor([4, 1]),
340+
(2,): torch.tensor([4, 5]),
341+
(3,): torch.tensor([10, 10, 13]),
342+
(4,): torch.tensor([10, 11, 12]),
343+
(5,): torch.tensor([13, 14, 15, 16, 21]),
344+
(6,): torch.tensor([13, 14, 15, 16, 17]),
366345
}
367346
runtime_intermediate_outputs = {
368347
(1, 2): [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],

0 commit comments

Comments
 (0)