Skip to content

Commit cf81301

Browse files
authored
Mapping between runtime and aot intermediate outputs
Differential Revision: D76442807 Pull Request resolved: #11624
1 parent 1af16cd commit cf81301

File tree

2 files changed

+258
-1
lines changed

2 files changed

+258
-1
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 179 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import math
1010
import sys
11+
from dataclasses import dataclass
1112
from enum import Enum
1213
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
1314

@@ -72,6 +73,25 @@ class TimeScale(Enum):
7273
}
7374

7475

76+
class NodeSource(Enum):
77+
AOT = 1
78+
RUNTIME = 2
79+
80+
81+
@dataclass
82+
class NodeData:
83+
"""
84+
Each node in the graph is an instance of NodeData, which contains:
85+
- source: A string indicating the origin of the node (either FROM_AOT or FROM_RUNTIME).
86+
- debug_handle: A tuple representing the unique identifier for the output.
87+
- output: The actual output data associated with the debug handle.
88+
"""
89+
90+
source: NodeSource
91+
debug_handle: tuple[int]
92+
output: Any
93+
94+
7595
def calculate_time_scale_factor(
7696
source_time_scale: TimeScale, target_time_scale: TimeScale
7797
) -> float:
@@ -489,7 +509,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
489509
"""
490510
Merge overlapping debug handles int a single key
491511
"""
492-
if not intermediate_outputs:
512+
if len(intermediate_outputs) == 0:
493513
return
494514
# Extract and normalize into (start, end, val)
495515
intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
@@ -512,3 +532,161 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
512532
intermediate_outputs.clear()
513533
for start, end, val in merged_intermediate_outputs:
514534
intermediate_outputs[tuple(range(start, end + 1))] = val
535+
536+
537+
def _debug_handles_have_overlap(
538+
aot_debug_hanlde: Tuple[int, ...], runtime_debug_handle: Tuple[int, ...]
539+
) -> bool:
540+
"""
541+
Check if the AOT debug handle and the runtime debug handle have any overlap.
542+
"""
543+
aot_set = set(aot_debug_hanlde)
544+
runtime_set = set(runtime_debug_handle)
545+
return len(aot_set.intersection(runtime_set)) > 0
546+
547+
548+
def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, ...]:
549+
"""Combine multiple debug handles into one debug handle"""
550+
combined_debug_handles_set = set()
551+
for debug_handle in debug_handles:
552+
combined_debug_handles_set.update(set(debug_handle))
553+
return tuple(sorted(combined_debug_handles_set))
554+
555+
556+
def _combine_overlapped_intermediate_outputs(
557+
nodes: List[Tuple[Tuple[int, ...], Any]]
558+
) -> Tuple[Tuple[int, ...], Any]:
559+
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
560+
debug_handles = [debug_handle for debug_handle, _ in nodes]
561+
outputs = [output for _, output in nodes]
562+
combined_debug_handle = _combine_debug_hanldes(debug_handles)
563+
output = outputs[-1] # Pick the last one
564+
return combined_debug_handle, output
565+
566+
567+
def _create_debug_handle_overlap_graph(
568+
aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
569+
runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
570+
) -> Tuple[List[NodeData], Dict[int, List[int]]]:
571+
"""
572+
Create a graph representing overlapping debug handles between AOT and runtime outputs.
573+
574+
Edges in the graph are represented as a dictionary where:
575+
- The key is the index of a node in the nodes list.
576+
- The value is a list of indices of nodes that have overlapping debug handles with the key node.
577+
578+
Returns:
579+
- A tuple containing:
580+
- A list of NodeData instances representing the nodes in the graph.
581+
- A dictionary representing the edges, where each key-value pair indicates connected nodes due to overlapping debug handles.
582+
"""
583+
nodes = []
584+
for debug_handle, output in aot_intermediate_outputs.items():
585+
nodes.append(NodeData(NodeSource.AOT, debug_handle, output))
586+
for debug_handle, output in runtime_intermediate_outputs.items():
587+
nodes.append(NodeData(NodeSource.RUNTIME, debug_handle, output))
588+
589+
edges = {i: [] for i in range(len(nodes))}
590+
for i in range(len(nodes)):
591+
for j in range(i + 1, len(nodes)):
592+
node_i = nodes[i]
593+
node_j = nodes[j]
594+
# Only connect nodes from different sources(aot vs runtime) that overlap
595+
if node_i.source != node_j.source and _debug_handles_have_overlap(
596+
node_i.debug_handle, node_j.debug_handle
597+
):
598+
edges[i].append(j)
599+
edges[j].append(i)
600+
return (nodes, edges)
601+
602+
603+
def _find_connected_components(
604+
nodes: List[NodeData], edges: Dict[int, List[int]]
605+
) -> List[List[int]]:
606+
"""
607+
Find groups of connected nodes in a graph using DFS.
608+
Parameters:
609+
- nodes: A list of nodes in the graph.
610+
- edges: A dictionary where each key is a node index, and the value is a list
611+
of indices of connected nodes.
612+
Returns:
613+
- A list of connected components, each represented as a list of node indices.
614+
"""
615+
visited = [False] * len(nodes)
616+
connected_components = []
617+
618+
def dfs(node_id, component):
619+
visited[node_id] = True
620+
component.append(node_id)
621+
# Iterate over all neighbors of the current node
622+
for neighbor_node_id in edges[node_id]:
623+
# If a neighbor has not been visited yet, recursively visit it
624+
if not visited[neighbor_node_id]:
625+
dfs(neighbor_node_id, component)
626+
627+
# Perform DFS on all nodes to find connected components
628+
for i in range(len(nodes)):
629+
# If a node has not been visited yet, start a new DFS from it
630+
if not visited[i]:
631+
component = []
632+
dfs(i, component)
633+
# After visiting all reachable nodes, add the current component to the list
634+
connected_components.append(component)
635+
return connected_components
636+
637+
638+
def map_runtime_aot_intermediate_outputs(
639+
aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
640+
runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
641+
) -> Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]]:
642+
"""
643+
Map the runtime intermediate outputs to the AOT intermediate outputs
644+
by finding overlapping debug handles and combining them into a single debug_handle
645+
646+
Returns:
647+
Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]] - Mapping
648+
from runtime intermediate output to AOT intermediate output
649+
"""
650+
# Merge overlapping debug handles
651+
merge_overlapping_debug_handles(aot_intermediate_outputs)
652+
merge_overlapping_debug_handles(runtime_intermediate_outputs)
653+
654+
# Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
655+
nodes, edges = _create_debug_handle_overlap_graph(
656+
aot_intermediate_outputs, runtime_intermediate_outputs
657+
)
658+
# Find connected(between aot and runtime) components
659+
connected_components = _find_connected_components(nodes, edges)
660+
661+
aot_runtime_mapping = {}
662+
for comp in connected_components:
663+
# Separate nodes into AOT and runtime lists based on their source,
664+
# each list is combined into a single element and mapped to each other.
665+
aot_list = [
666+
(nodes[node_id].debug_handle, nodes[node_id].output)
667+
for node_id in comp
668+
if nodes[node_id].source == NodeSource.AOT
669+
]
670+
runtime_list = [
671+
(nodes[node_id].debug_handle, nodes[node_id].output)
672+
for node_id in comp
673+
if nodes[node_id].source == NodeSource.RUNTIME
674+
]
675+
676+
# Map only if both AOT and runtime data are present.
677+
if len(aot_list) != 0 and len(runtime_list) != 0:
678+
# Combine aot debug handles into a single key
679+
aot_combined_debug_handle, aot_output = (
680+
_combine_overlapped_intermediate_outputs(aot_list)
681+
)
682+
# Combine runtime debug handles into a single key
683+
runtime_combined_debug_handle, runtime_output = (
684+
_combine_overlapped_intermediate_outputs(runtime_list)
685+
)
686+
# Create a mapping between runtime and aot
687+
aot_runtime_mapping[(aot_combined_debug_handle, aot_output)] = (
688+
runtime_combined_debug_handle,
689+
runtime_output,
690+
)
691+
692+
return aot_runtime_mapping

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
find_populated_event,
3535
gen_graphs_from_etrecord,
3636
is_inference_output_equal,
37+
map_runtime_aot_intermediate_outputs,
3738
merge_overlapping_debug_handles,
3839
TimeScale,
3940
)
@@ -238,6 +239,84 @@ def test_merge_overlapping_debug_handles(self):
238239
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
239240
self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)
240241

242+
def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
243+
# When the inputs are empty, the output should also be empty
244+
aot_intermediate_outputs = {}
245+
runtime_intermediate_outputs = {}
246+
actual = map_runtime_aot_intermediate_outputs(
247+
aot_intermediate_outputs, runtime_intermediate_outputs
248+
)
249+
expected = {}
250+
self.assertEqual(actual, expected)
251+
252+
def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
253+
# Single element tuple
254+
aot_intermediate_outputs = {(0,): 100, (1,): 200, (2,): 300}
255+
runtime_intermediate_outputs = {(0,): 150, (1,): 250, (2,): 350}
256+
actual = map_runtime_aot_intermediate_outputs(
257+
aot_intermediate_outputs, runtime_intermediate_outputs
258+
)
259+
expected = {
260+
((0,), 100): ((0,), 150),
261+
((1,), 200): ((1,), 250),
262+
((2,), 300): ((2,), 350),
263+
}
264+
self.assertEqual(actual, expected)
265+
266+
def test_map_runtime_aot_intermediate_outputs_exact_match(self):
267+
# Exact match between aot and runtime debug_handles
268+
aot_intermediate_outputs = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
269+
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
270+
actual = map_runtime_aot_intermediate_outputs(
271+
aot_intermediate_outputs, runtime_intermediate_outputs
272+
)
273+
expected = {
274+
((0, 1), 100): ((0, 1), 150),
275+
((2, 3), 200): ((2, 3), 200),
276+
((4, 5), 300): ((4, 5), 300),
277+
}
278+
self.assertEqual(actual, expected)
279+
280+
def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
281+
# No overlaps between aot and runtime debug_handles
282+
aot_intermediate_outputs = {(0, 1): 100, (4, 5): 300}
283+
runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
284+
actual = map_runtime_aot_intermediate_outputs(
285+
aot_intermediate_outputs, runtime_intermediate_outputs
286+
)
287+
expected = {}
288+
self.assertEqual(actual, expected)
289+
290+
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
291+
# Multiple aot debug_handles map to one runtime debug_handle
292+
aot_intermediate_outputs = {(0, 1, 2): 100, (3, 4): 300}
293+
runtime_intermediate_outputs = {(1, 2, 3): 250, (8, 9): 300}
294+
actual = map_runtime_aot_intermediate_outputs(
295+
aot_intermediate_outputs, runtime_intermediate_outputs
296+
)
297+
expected = {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)}
298+
self.assertEqual(actual, expected)
299+
300+
def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
301+
# One aot debug_handle map to multiple runtime debug_handles
302+
aot_intermediate_outputs = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
303+
runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
304+
actual = map_runtime_aot_intermediate_outputs(
305+
aot_intermediate_outputs, runtime_intermediate_outputs
306+
)
307+
expected = {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)}
308+
self.assertEqual(actual, expected)
309+
310+
def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
311+
# Complex chain (N-to-N mapping)
312+
aot_intermediate_outputs = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
313+
runtime_intermediate_outputs = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
314+
actual = map_runtime_aot_intermediate_outputs(
315+
aot_intermediate_outputs, runtime_intermediate_outputs
316+
)
317+
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
318+
self.assertEqual(actual, expected)
319+
241320

242321
def gen_mock_operator_graph_with_expected_map() -> (
243322
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)