Skip to content

Commit d68f44f

Browse files
Juntian777 authored and facebook-github-bot committed
Updated merging logic of AOT to get the corresponding intermediate output (#12351)
Summary: This PR updates the mapping between AOT and runtime intermediate outputs. The runtime_list has a size of 1 because each AOT debug_handle is either an integer or a list of integers with a size of 1, so no merging is needed for runtime_list during the mapping stage. However, the AOT entries may need to be merged during the mapping stage, and the intermediate output that is retained is the one selected by the last integer of the corresponding runtime debug_handle tuple. This ensures that the AOT and runtime intermediate outputs match. Additionally, this PR ensures that the merged AOT debug_handle exactly matches the runtime debug_handle. Differential Revision: D77956435
1 parent 3afd18d commit d68f44f

File tree

2 files changed

+77
-78
lines changed

2 files changed

+77
-78
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -625,23 +625,29 @@ def _debug_handles_have_overlap(
625625
return len(aot_set.intersection(runtime_set)) > 0
626626

627627

628-
def _combine_debug_handles(debug_handles: List[DebugHandle]) -> DebugHandle:
629-
"""Combine multiple debug handles into one debug handle"""
630-
combined_debug_handles_set = set()
631-
for debug_handle in debug_handles:
632-
combined_debug_handles_set.update(set(debug_handle))
633-
return tuple(sorted(combined_debug_handles_set))
628+
def _combine_aot_overlapped_intermediate_outputs(
629+
aot_nodes: List[Tuple[DebugHandle, Any]], runtime_node: Tuple[DebugHandle, Any]
630+
) -> Tuple[DebugHandle, Any]:
631+
"""
632+
Ensure the AOT combined debug_handles are the same as the runtime debug handles. (order ignored)
633+
then pick the last intermediate output based on the runtime debug handles
634+
"""
635+
# Map aot single element debug_handles to outputs
636+
aot_map = dict(aot_nodes)
637+
runtime_debug_handle, _ = runtime_node
634638

639+
# Combine all aot debug_handles into a list
640+
aot_combined_debug_handle = [t[0] for t in aot_map.keys()]
635641

636-
def _combine_overlapped_intermediate_outputs(
637-
nodes: List[Tuple[DebugHandle, Any]]
638-
) -> Tuple[DebugHandle, Any]:
639-
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
640-
debug_handles = [debug_handle for debug_handle, _ in nodes]
641-
outputs = [output for _, output in nodes]
642-
combined_debug_handle = _combine_debug_handles(debug_handles)
643-
output = outputs[-1] # Pick the last one
644-
return combined_debug_handle, output
642+
if set(aot_combined_debug_handle) != set(runtime_debug_handle):
643+
raise ValueError(
644+
f"Debug handles {aot_combined_debug_handle} and {runtime_debug_handle} do not match."
645+
)
646+
647+
# Pick the last intermediate output
648+
last_int = runtime_debug_handle[-1]
649+
key = (last_int,)
650+
return runtime_debug_handle, aot_map[key]
645651

646652

647653
def _create_debug_handle_overlap_graph(
@@ -751,38 +757,48 @@ def map_runtime_aot_intermediate_outputs(
751757

752758
# Map only if both AOT and runtime data are present.
753759
if len(aot_list) != 0 and len(runtime_list) != 0:
760+
# The size of runtime_list should be 1 because,
761+
# in AOT all debug_handles are tuples with only one single element.
762+
# No AOT node will map to multiple runtime nodes.
763+
assert len(runtime_list) == 1
764+
runtime_debug_handle, runtime_intermediate_output = runtime_list[0]
765+
754766
# Combine aot debug handles into a single key
755767
aot_combined_debug_handle, aot_intermediate_output = (
756-
_combine_overlapped_intermediate_outputs(aot_list)
768+
_combine_aot_overlapped_intermediate_outputs(aot_list, runtime_list[0])
757769
)
758-
# Combine runtime debug handles into a single key
759-
runtime_combined_debug_handle, runtime_intermediate_output = (
760-
_combine_overlapped_intermediate_outputs(runtime_list)
761-
)
762-
# List can't be used as a key, so convert to tuple
763-
if isinstance(aot_intermediate_output, list):
770+
771+
if isinstance(aot_intermediate_output, Sequence):
772+
if not isinstance(runtime_intermediate_output, Sequence):
773+
raise TypeError(
774+
"runtime intermediate output should be a sequence when aot intermediate output is a sequence"
775+
)
776+
last_element = runtime_intermediate_output[-1]
777+
if isinstance(last_element, list) and all(
778+
isinstance(t, torch.Tensor) for t in last_element
779+
):
780+
# If the last element is a list of tensors (delegate case)
781+
runtime_intermediate_output = last_element
782+
elif isinstance(last_element, torch.Tensor):
783+
# If the last element is a tensor (non-delegate case)
784+
pass
785+
else:
786+
raise ValueError(
787+
"The last element of runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence"
788+
)
789+
# List can't be used as a key, so convert to tuple
764790
aot_intermediate_output = tuple(aot_intermediate_output)
765-
# runtime follow the same format as aot, so it's safe to convert to tuple
766-
if isinstance(runtime_intermediate_output, list):
767791
runtime_intermediate_output = tuple(runtime_intermediate_output)
768792

769-
# Currently, runtime_intermediate_output logs all delegate call arguments.
770-
# Process here to extract only the outputs.
771-
if isinstance(aot_intermediate_output, tuple):
772-
# If both are sequences, slice runtime_intermediate_output to match the length of aot_intermediate_output
773-
if isinstance(runtime_intermediate_output, tuple):
774-
runtime_intermediate_output = runtime_intermediate_output[
775-
-len(aot_intermediate_output) :
776-
]
777-
# If aot_intermediate_output is not a sequence but runtime_intermediate_output is, get the last element
778-
elif isinstance(runtime_intermediate_output, tuple):
793+
elif isinstance(runtime_intermediate_output, Sequence):
794+
# delegate runtime call and AOT intermediate is not a sequence, just take the last element from runtime list
779795
runtime_intermediate_output = runtime_intermediate_output[-1]
780796

781797
# Create a mapping between runtime and aot
782798
aot_runtime_mapping[
783799
(aot_combined_debug_handle, aot_intermediate_output)
784800
] = (
785-
runtime_combined_debug_handle,
801+
runtime_debug_handle,
786802
runtime_intermediate_output,
787803
)
788804

@@ -890,7 +906,9 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
890906
if is_a_sequence and is_b_sequence:
891907
# Ensure both sequences have the same length
892908
if len(a) != len(b):
893-
raise ValueError("Sequences must have the same length for comparison.")
909+
raise ValueError(
910+
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
911+
)
894912

895913
# Compare each element in the sequences and return the list of results
896914
return [comparator.compare(x, y) for x, y in zip(a, b)]
@@ -899,7 +917,9 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
899917
return [comparator.compare(a, b)]
900918
else:
901919
# Raise an error if one is a sequence and the other is not
902-
raise ValueError("Both inputs must be sequences or both must be non-sequences.")
920+
raise ValueError(
921+
f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences."
922+
)
903923

904924

905925
def propagate_back_debug_handle(

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 20 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -302,67 +302,46 @@ def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
302302
}
303303
self.assertEqual(actual, expected)
304304

305-
def test_map_runtime_aot_intermediate_outputs_exact_match(self):
306-
# Exact match between aot and runtime debug_handles
307-
aot_intermediate_outputs = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
308-
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
309-
actual = map_runtime_aot_intermediate_outputs(
310-
aot_intermediate_outputs, runtime_intermediate_outputs
311-
)
312-
expected = {
313-
((0, 1), 100): ((0, 1), 150),
314-
((2, 3), 200): ((2, 3), 200),
315-
((4, 5), 300): ((4, 5), 300),
316-
}
317-
self.assertEqual(actual, expected)
318-
319305
def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
320306
# No overlaps between aot and runtime debug_handles
321-
aot_intermediate_outputs = {(0, 1): 100, (4, 5): 300}
307+
aot_intermediate_outputs = {(0,): 100, (4,): 300}
322308
runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
323309
actual = map_runtime_aot_intermediate_outputs(
324310
aot_intermediate_outputs, runtime_intermediate_outputs
325311
)
326312
expected = {}
327313
self.assertEqual(actual, expected)
328314

329-
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
330-
# Multiple aot debug_handles map to one runtime debug_handle
331-
aot_intermediate_outputs = {(0, 1, 2): 100, (3, 4): 300}
332-
runtime_intermediate_outputs = {(1, 2, 3): 250, (8, 9): 300}
333-
actual = map_runtime_aot_intermediate_outputs(
334-
aot_intermediate_outputs, runtime_intermediate_outputs
335-
)
336-
expected = {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)}
337-
self.assertEqual(actual, expected)
315+
def test_map_runtime_aot_intermediate_outputs_partial_match(self):
316+
# Partial match between aot and runtime debug_handles will raise an error
317+
aot_intermediate_outputs = {(2,): 100, (4,): 300}
318+
runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
338319

339-
def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
340-
# One aot debug_handle map to multiple runtime debug_handles
341-
aot_intermediate_outputs = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
342-
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
343-
actual = map_runtime_aot_intermediate_outputs(
344-
aot_intermediate_outputs, runtime_intermediate_outputs
345-
)
346-
expected = {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)}
347-
self.assertEqual(actual, expected)
320+
with self.assertRaises(ValueError):
321+
map_runtime_aot_intermediate_outputs(
322+
aot_intermediate_outputs, runtime_intermediate_outputs
323+
)
348324

349-
def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
350-
# Complex chain (N-to-N mapping)
351-
aot_intermediate_outputs = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
352-
runtime_intermediate_outputs = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
325+
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
326+
# Multiple aot debug_handles map to one runtime debug_handle
327+
aot_intermediate_outputs = {(0,): 100, (1,): 200, (2,): 300, (3,): 400}
328+
runtime_intermediate_outputs = {(2, 3, 1): 250, (8, 9): 300}
353329
actual = map_runtime_aot_intermediate_outputs(
354330
aot_intermediate_outputs, runtime_intermediate_outputs
355331
)
356-
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
332+
expected = {((2, 3, 1), 200): ((2, 3, 1), 250)}
357333
self.assertEqual(actual, expected)
358334

359335
def test_map_runtime_aot_intermediate_outputs_delegated(self):
360336
# Currently, runtime_intermediate_output logs all delegate call arguments
361337
# Test that the map function correctly extracted out the delegated outputs
362338
aot_intermediate_outputs = {
363-
(1, 2): torch.tensor([4, 5]),
364-
(3, 4): torch.tensor([10, 11, 12]),
365-
(5, 6): torch.tensor([13, 14, 15, 16, 17]),
339+
(1,): torch.tensor([4, 1]),
340+
(2,): torch.tensor([4, 5]),
341+
(3,): torch.tensor([10, 10, 13]),
342+
(4,): torch.tensor([10, 11, 12]),
343+
(5,): torch.tensor([13, 14, 15, 16, 21]),
344+
(6,): torch.tensor([13, 14, 15, 16, 17]),
366345
}
367346
runtime_intermediate_outputs = {
368347
(1, 2): [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],

0 commit comments

Comments
 (0)