Skip to content

Commit d50abdf

Browse files
Juntian777 and facebook-github-bot
authored and committed
Updated merging logic of AOT to get the corresponding intermediate output
Summary: This PR updates the mapping between AOT and runtime intermediate outputs. The runtime_list has a size of 1 because each AOT debug_handle is either an integer or a list of integers with a size of 1, so no merging is needed for runtime_list during the mapping stage. However, the AOT entries may need to be merged during the mapping stage, and the intermediate output that is retained should be the one selected by the last integer of the corresponding runtime debug_handle tuple. This ensures that the AOT and runtime intermediate outputs match. Additionally, this PR ensures that the merged AOT debug_handle exactly matches the runtime debug_handle. Differential Revision: D77956435
1 parent 540fa5d commit d50abdf

File tree

2 files changed

+67
-78
lines changed

2 files changed

+67
-78
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -616,23 +616,29 @@ def _debug_handles_have_overlap(
616616
return len(aot_set.intersection(runtime_set)) > 0
617617

618618

619-
def _combine_debug_handles(debug_handles: List[DebugHandle]) -> DebugHandle:
620-
"""Combine multiple debug handles into one debug handle"""
621-
combined_debug_handles_set = set()
622-
for debug_handle in debug_handles:
623-
combined_debug_handles_set.update(set(debug_handle))
624-
return tuple(sorted(combined_debug_handles_set))
619+
def _combine_aot_overlapped_intermediate_outputs(
620+
aot_nodes: List[Tuple[DebugHandle, Any]], runtime_node: Tuple[DebugHandle, Any]
621+
) -> Tuple[DebugHandle, Any]:
622+
"""
623+
Ensure the AOT combined debug_handles are the same as the runtime debug handles. (order ignored)
624+
then pick the last intermediate output based on the runtime debug handles
625+
"""
626+
# Map aot single element debug_handles to outputs
627+
aot_map = dict(aot_nodes)
628+
runtime_debug_handle, _ = runtime_node
625629

630+
# Combine all aot debug_handles into a list
631+
aot_combined_debug_handle = [t[0] for t in aot_map.keys()]
626632

627-
def _combine_overlapped_intermediate_outputs(
628-
nodes: List[Tuple[DebugHandle, Any]]
629-
) -> Tuple[DebugHandle, Any]:
630-
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
631-
debug_handles = [debug_handle for debug_handle, _ in nodes]
632-
outputs = [output for _, output in nodes]
633-
combined_debug_handle = _combine_debug_handles(debug_handles)
634-
output = outputs[-1] # Pick the last one
635-
return combined_debug_handle, output
633+
if set(aot_combined_debug_handle) != set(runtime_debug_handle):
634+
raise ValueError(
635+
f"Debug handles {aot_combined_debug_handle} and {runtime_debug_handle} do not match."
636+
)
637+
638+
# Pick the last intermediate output
639+
last_int = runtime_debug_handle[-1]
640+
key = (last_int,)
641+
return runtime_debug_handle, aot_map[key]
636642

637643

638644
def _create_debug_handle_overlap_graph(
@@ -742,38 +748,42 @@ def map_runtime_aot_intermediate_outputs(
742748

743749
# Map only if both AOT and runtime data are present.
744750
if len(aot_list) != 0 and len(runtime_list) != 0:
751+
# The size of runtime_list should be 1 because,
752+
# in AOT all debug_handles are tuples with only one single element.
753+
# No AOT node will map to multiple runtime nodes.
754+
assert len(runtime_list) == 1
755+
runtime_debug_handle, runtime_intermediate_output = runtime_list[0]
756+
745757
# Combine aot debug handles into a single key
746758
aot_combined_debug_handle, aot_intermediate_output = (
747-
_combine_overlapped_intermediate_outputs(aot_list)
748-
)
749-
# Combine runtime debug handles into a single key
750-
runtime_combined_debug_handle, runtime_intermediate_output = (
751-
_combine_overlapped_intermediate_outputs(runtime_list)
759+
_combine_aot_overlapped_intermediate_outputs(aot_list, runtime_list[0])
752760
)
753-
# List can't be used as a key, so convert to tuple
754-
if isinstance(aot_intermediate_output, list):
761+
762+
if isinstance(aot_intermediate_output, Sequence):
763+
if not isinstance(runtime_intermediate_output, Sequence):
764+
raise TypeError("runtime intermediate output should be a sequence when aot intermediate output is a sequence")
765+
last_element = runtime_intermediate_output[-1]
766+
if isinstance(last_element, list) and all(isinstance(t, torch.Tensor) for t in last_element):
767+
# If the last element is a list of tensors (delegate case)
768+
runtime_intermediate_output = last_element
769+
elif isinstance(last_element, torch.Tensor):
770+
# If the last element is a tensor (non-delegate case)
771+
pass
772+
else:
773+
raise ValueError("The last element of runtime intermediate output sequence must be a tensor or a list of tensors")
774+
# List can't be used as a key, so convert to tuple
755775
aot_intermediate_output = tuple(aot_intermediate_output)
756-
# runtime follow the same format as aot, so it's safe to convert to tuple
757-
if isinstance(runtime_intermediate_output, list):
758776
runtime_intermediate_output = tuple(runtime_intermediate_output)
759777

760-
# Currently, runtime_intermediate_output logs all delegate call arguments.
761-
# Process here to extract only the outputs.
762-
if isinstance(aot_intermediate_output, tuple):
763-
# If both are sequences, slice runtime_intermediate_output to match the length of aot_intermediate_output
764-
if isinstance(runtime_intermediate_output, tuple):
765-
runtime_intermediate_output = runtime_intermediate_output[
766-
-len(aot_intermediate_output) :
767-
]
768-
# If aot_intermediate_output is not a sequence but runtime_intermediate_output is, get the last element
769-
elif isinstance(runtime_intermediate_output, tuple):
778+
elif isinstance(runtime_intermediate_output, Sequence):
779+
# delegate case, just take the last element
770780
runtime_intermediate_output = runtime_intermediate_output[-1]
771781

772782
# Create a mapping between runtime and aot
773783
aot_runtime_mapping[
774784
(aot_combined_debug_handle, aot_intermediate_output)
775785
] = (
776-
runtime_combined_debug_handle,
786+
runtime_debug_handle,
777787
runtime_intermediate_output,
778788
)
779789

@@ -878,7 +888,7 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
878888
if is_a_sequence and is_b_sequence:
879889
# Ensure both sequences have the same length
880890
if len(a) != len(b):
881-
raise ValueError("Sequences must have the same length for comparison.")
891+
raise ValueError(f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison.")
882892

883893
# Compare each element in the sequences and return the list of results
884894
return [comparator.compare(x, y) for x, y in zip(a, b)]
@@ -887,4 +897,4 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
887897
return [comparator.compare(a, b)]
888898
else:
889899
# Raise an error if one is a sequence and the other is not
890-
raise ValueError("Both inputs must be sequences or both must be non-sequences.")
900+
raise ValueError(f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences.")

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 20 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -297,67 +297,46 @@ def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
297297
}
298298
self.assertEqual(actual, expected)
299299

300-
def test_map_runtime_aot_intermediate_outputs_exact_match(self):
301-
# Exact match between aot and runtime debug_handles
302-
aot_intermediate_outputs = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
303-
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
304-
actual = map_runtime_aot_intermediate_outputs(
305-
aot_intermediate_outputs, runtime_intermediate_outputs
306-
)
307-
expected = {
308-
((0, 1), 100): ((0, 1), 150),
309-
((2, 3), 200): ((2, 3), 200),
310-
((4, 5), 300): ((4, 5), 300),
311-
}
312-
self.assertEqual(actual, expected)
313-
314300
def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
315301
# No overlaps between aot and runtime debug_handles
316-
aot_intermediate_outputs = {(0, 1): 100, (4, 5): 300}
302+
aot_intermediate_outputs = {(0,): 100, (4,): 300}
317303
runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
318304
actual = map_runtime_aot_intermediate_outputs(
319305
aot_intermediate_outputs, runtime_intermediate_outputs
320306
)
321307
expected = {}
322308
self.assertEqual(actual, expected)
323309

324-
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
325-
# Multiple aot debug_handles map to one runtime debug_handle
326-
aot_intermediate_outputs = {(0, 1, 2): 100, (3, 4): 300}
327-
runtime_intermediate_outputs = {(1, 2, 3): 250, (8, 9): 300}
328-
actual = map_runtime_aot_intermediate_outputs(
329-
aot_intermediate_outputs, runtime_intermediate_outputs
330-
)
331-
expected = {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)}
332-
self.assertEqual(actual, expected)
310+
def test_map_runtime_aot_intermediate_outputs_partial_match(self):
311+
# Partial match between aot and runtime debug_handles will raise an error
312+
aot_intermediate_outputs = {(2,): 100, (4,): 300}
313+
runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
333314

334-
def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
335-
# One aot debug_handle map to multiple runtime debug_handles
336-
aot_intermediate_outputs = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
337-
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
338-
actual = map_runtime_aot_intermediate_outputs(
339-
aot_intermediate_outputs, runtime_intermediate_outputs
340-
)
341-
expected = {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)}
342-
self.assertEqual(actual, expected)
315+
with self.assertRaises(ValueError):
316+
map_runtime_aot_intermediate_outputs(
317+
aot_intermediate_outputs, runtime_intermediate_outputs
318+
)
343319

344-
def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
345-
# Complex chain (N-to-N mapping)
346-
aot_intermediate_outputs = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
347-
runtime_intermediate_outputs = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
320+
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
321+
# Multiple aot debug_handles map to one runtime debug_handle
322+
aot_intermediate_outputs = {(0,): 100, (1,): 200, (2,): 300, (3,): 400}
323+
runtime_intermediate_outputs = {(2, 3, 1): 250, (8, 9): 300}
348324
actual = map_runtime_aot_intermediate_outputs(
349325
aot_intermediate_outputs, runtime_intermediate_outputs
350326
)
351-
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
327+
expected = {((2, 3, 1), 200): ((2, 3, 1), 250)}
352328
self.assertEqual(actual, expected)
353329

354330
def test_map_runtime_aot_intermediate_outputs_delegated(self):
355331
# Currently, runtime_intermediate_output logs all delegate call arguments
356332
# Test that the map function correctly extracted out the delegated outputs
357333
aot_intermediate_outputs = {
358-
(1, 2): torch.tensor([4, 5]),
359-
(3, 4): torch.tensor([10, 11, 12]),
360-
(5, 6): torch.tensor([13, 14, 15, 16, 17]),
334+
(1,): torch.tensor([4, 1]),
335+
(2,): torch.tensor([4, 5]),
336+
(3,): torch.tensor([10, 10, 13]),
337+
(4,): torch.tensor([10, 11, 12]),
338+
(5,): torch.tensor([13, 14, 15, 16, 21]),
339+
(6,): torch.tensor([13, 14, 15, 16, 17]),
361340
}
362341
runtime_intermediate_outputs = {
363342
(1, 2): [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],

0 commit comments

Comments
 (0)