Skip to content

Updated merging logic of AOT to get the corresponding intermediate output #12351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 60 additions & 37 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,23 +625,28 @@ def _debug_handles_have_overlap(
return len(aot_set.intersection(runtime_set)) > 0


def _combine_debug_handles(debug_handles: List[DebugHandle]) -> DebugHandle:
"""Combine multiple debug handles into one debug handle"""
combined_debug_handles_set = set()
for debug_handle in debug_handles:
combined_debug_handles_set.update(set(debug_handle))
return tuple(sorted(combined_debug_handles_set))
def _combine_aot_overlapped_intermediate_outputs(
aot_nodes: List[Tuple[DebugHandle, Any]], runtime_node: Tuple[DebugHandle, Any]
) -> Tuple[DebugHandle, Any]:
"""
Ensure the AOT combined debug_handles are the same as the runtime debug_handles (order ignored),
then pick the last intermediate output based on the runtime debug_handles
"""
# Map AOT single element debug_handles to outputs
aot_map = dict(aot_nodes)
runtime_debug_handle, _ = runtime_node

# Combine all AOT debug_handles into a list
aot_combined_debug_handle = [t[0] for t in aot_map.keys()]

def _combine_overlapped_intermediate_outputs(
nodes: List[Tuple[DebugHandle, Any]]
) -> Tuple[DebugHandle, Any]:
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
debug_handles = [debug_handle for debug_handle, _ in nodes]
outputs = [output for _, output in nodes]
combined_debug_handle = _combine_debug_handles(debug_handles)
output = outputs[-1] # Pick the last one
return combined_debug_handle, output
if set(aot_combined_debug_handle) != set(runtime_debug_handle):
# AOT combined debug_handle and runtime debug_handle do not match.
return (-1,), None

# Pick the last intermediate output
last_int = runtime_debug_handle[-1]
key = (last_int,)
return runtime_debug_handle, aot_map[key]


def _create_debug_handle_overlap_graph(
Expand Down Expand Up @@ -751,38 +756,52 @@ def map_runtime_aot_intermediate_outputs(

# Map only if both AOT and runtime data are present.
if len(aot_list) != 0 and len(runtime_list) != 0:
# The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
# Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_handles.
# As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
assert len(runtime_list) == 1
runtime_debug_handle, runtime_intermediate_output = runtime_list[0]

# Combine aot debug handles into a single key
aot_combined_debug_handle, aot_intermediate_output = (
_combine_overlapped_intermediate_outputs(aot_list)
_combine_aot_overlapped_intermediate_outputs(aot_list, runtime_list[0])
)
# Combine runtime debug handles into a single key
runtime_combined_debug_handle, runtime_intermediate_output = (
_combine_overlapped_intermediate_outputs(runtime_list)
)
# List can't be used as a key, so convert to tuple
if isinstance(aot_intermediate_output, list):

if aot_combined_debug_handle == (-1,):
# Skip this mapping if the AOT combined debug handle and the runtime debug handle do not match exactly.
continue

if isinstance(aot_intermediate_output, Sequence):
if not isinstance(runtime_intermediate_output, Sequence):
raise TypeError(
"runtime intermediate output should be a sequence when aot intermediate output is a sequence"
)
last_element = runtime_intermediate_output[-1]
if isinstance(last_element, list) and all(
isinstance(t, torch.Tensor) for t in last_element
):
# If the last element is a list of tensors (delegate case)
runtime_intermediate_output = last_element
elif isinstance(last_element, torch.Tensor):
# If the last element is a tensor (non-delegate case)
pass
else:
raise ValueError(
"The last element of runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence"
)
# List can't be used as a key, so convert to tuple
aot_intermediate_output = tuple(aot_intermediate_output)
# runtime follows the same format as aot, so it's safe to convert to tuple
if isinstance(runtime_intermediate_output, list):
runtime_intermediate_output = tuple(runtime_intermediate_output)

# Currently, runtime_intermediate_output logs all delegate call arguments.
# Process here to extract only the outputs.
if isinstance(aot_intermediate_output, tuple):
# If both are sequences, slice runtime_intermediate_output to match the length of aot_intermediate_output
if isinstance(runtime_intermediate_output, tuple):
runtime_intermediate_output = runtime_intermediate_output[
-len(aot_intermediate_output) :
]
# If aot_intermediate_output is not a sequence but runtime_intermediate_output is, get the last element
elif isinstance(runtime_intermediate_output, tuple):
elif isinstance(runtime_intermediate_output, Sequence):
# delegate runtime call and AOT intermediate is not a sequence, just take the last element from runtime list
runtime_intermediate_output = runtime_intermediate_output[-1]

# Create a mapping between runtime and aot
aot_runtime_mapping[
(aot_combined_debug_handle, aot_intermediate_output)
] = (
runtime_combined_debug_handle,
runtime_debug_handle,
runtime_intermediate_output,
)

Expand Down Expand Up @@ -890,7 +909,9 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
if is_a_sequence and is_b_sequence:
# Ensure both sequences have the same length
if len(a) != len(b):
raise ValueError("Sequences must have the same length for comparison.")
raise ValueError(
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
)

# Compare each element in the sequences and return the list of results
return [comparator.compare(x, y) for x, y in zip(a, b)]
Expand All @@ -899,7 +920,9 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
return [comparator.compare(a, b)]
else:
# Raise an error if one is a sequence and the other is not
raise ValueError("Both inputs must be sequences or both must be non-sequences.")
raise ValueError(
f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences."
)


def propagate_back_debug_handle(
Expand Down
55 changes: 17 additions & 38 deletions devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,67 +302,46 @@ def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
}
self.assertEqual(actual, expected)

def test_map_runtime_aot_intermediate_outputs_exact_match(self):
    """Check that identical AOT and runtime debug_handles are paired one-to-one."""
    # Exact match between aot and runtime debug_handles
    aot_intermediate_outputs = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
    runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
    actual = map_runtime_aot_intermediate_outputs(
        aot_intermediate_outputs, runtime_intermediate_outputs
    )
    # Keys are (aot_handle, aot_output); values are (runtime_handle, runtime_output).
    expected = {
        ((0, 1), 100): ((0, 1), 150),
        ((2, 3), 200): ((2, 3), 200),
        ((4, 5), 300): ((4, 5), 300),
    }
    self.assertEqual(actual, expected)

def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
    """Check that disjoint AOT and runtime debug_handles yield an empty mapping."""
    # No overlaps between aot and runtime debug_handles
    aot_intermediate_outputs = {(0,): 100, (4,): 300}
    runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
    actual = map_runtime_aot_intermediate_outputs(
        aot_intermediate_outputs, runtime_intermediate_outputs
    )
    expected = {}
    self.assertEqual(actual, expected)

def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
    """Check the mapping produced when several AOT debug_handles overlap one runtime debug_handle."""
    # Multiple aot debug_handles map to one runtime debug_handle
    aot_intermediate_outputs = {(0, 1, 2): 100, (3, 4): 300}
    runtime_intermediate_outputs = {(1, 2, 3): 250, (8, 9): 300}
    actual = map_runtime_aot_intermediate_outputs(
        aot_intermediate_outputs, runtime_intermediate_outputs
    )
    # The runtime entry (8, 9) overlaps nothing on the AOT side and is absent from the result.
    expected = {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)}
    self.assertEqual(actual, expected)

def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
# One aot debug_handle map to multiple runtime debug_handles
aot_intermediate_outputs = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
def test_map_runtime_aot_intermediate_outputs_partial_match(self):
# Partial match between aot and runtime debug_handles will return empty
aot_intermediate_outputs = {(2,): 100, (9,): 300}
runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
actual = map_runtime_aot_intermediate_outputs(
aot_intermediate_outputs, runtime_intermediate_outputs
)
expected = {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)}
expected = {}
self.assertEqual(actual, expected)

def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
# Complex chain (N-to-N mapping)
aot_intermediate_outputs = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
runtime_intermediate_outputs = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
# Multiple aot debug_handles map to one runtime debug_handle
aot_intermediate_outputs = {(0,): 100, (1,): 200, (2,): 300, (3,): 400}
runtime_intermediate_outputs = {(2, 3, 1): 250, (8, 9): 300}
actual = map_runtime_aot_intermediate_outputs(
aot_intermediate_outputs, runtime_intermediate_outputs
)
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
expected = {((2, 3, 1), 200): ((2, 3, 1), 250)}
self.assertEqual(actual, expected)

def test_map_runtime_aot_intermediate_outputs_delegated(self):
# Currently, runtime_intermediate_output logs all delegate call arguments
# Test that the map function correctly extracted out the delegated outputs
aot_intermediate_outputs = {
(1, 2): torch.tensor([4, 5]),
(3, 4): torch.tensor([10, 11, 12]),
(5, 6): torch.tensor([13, 14, 15, 16, 17]),
(1,): torch.tensor([4, 1]),
(2,): torch.tensor([4, 5]),
(3,): torch.tensor([10, 10, 13]),
(4,): torch.tensor([10, 11, 12]),
(5,): torch.tensor([13, 14, 15, 16, 21]),
(6,): torch.tensor([13, 14, 15, 16, 17]),
}
runtime_intermediate_outputs = {
(1, 2): [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],
Expand Down
Loading