8
8
9
9
import math
10
10
import sys
11
+ from dataclasses import dataclass
11
12
from enum import Enum
12
13
from typing import Any , Dict , IO , List , Mapping , Optional , Tuple , TypeAlias , Union
13
14
@@ -72,6 +73,25 @@ class TimeScale(Enum):
72
73
}
73
74
74
75
76
+ class NodeSource (Enum ):
77
+ AOT = 1
78
+ RUNTIME = 2
79
+
80
+
81
+ @dataclass
82
+ class NodeData :
83
+ """
84
+ Each node in the graph is an instance of NodeData, which contains:
85
+ - source: A string indicating the origin of the node (either FROM_AOT or FROM_RUNTIME).
86
+ - debug_handle: A tuple representing the unique identifier for the output.
87
+ - output: The actual output data associated with the debug handle.
88
+ """
89
+
90
+ source : NodeSource
91
+ debug_handle : tuple [int ]
92
+ output : Any
93
+
94
+
75
95
def calculate_time_scale_factor (
76
96
source_time_scale : TimeScale , target_time_scale : TimeScale
77
97
) -> float :
@@ -489,7 +509,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
489
509
"""
490
510
Merge overlapping debug handles int a single key
491
511
"""
492
- if not intermediate_outputs :
512
+ if len ( intermediate_outputs ) == 0 :
493
513
return
494
514
# Extract and normalize into (start, end, val)
495
515
intervals = [(min (key ), max (key ), val ) for key , val in intermediate_outputs .items ()]
@@ -512,3 +532,161 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
512
532
intermediate_outputs .clear ()
513
533
for start , end , val in merged_intermediate_outputs :
514
534
intermediate_outputs [tuple (range (start , end + 1 ))] = val
535
+
536
+
537
+ def _debug_handles_have_overlap (
538
+ aot_debug_hanlde : Tuple [int , ...], runtime_debug_handle : Tuple [int , ...]
539
+ ) -> bool :
540
+ """
541
+ Check if the AOT debug handle and the runtime debug handle have any overlap.
542
+ """
543
+ aot_set = set (aot_debug_hanlde )
544
+ runtime_set = set (runtime_debug_handle )
545
+ return len (aot_set .intersection (runtime_set )) > 0
546
+
547
+
548
+ def _combine_debug_hanldes (debug_handles : List [Tuple [int , ...]]) -> Tuple [int , ...]:
549
+ """Combine multiple debug handles into one debug handle"""
550
+ combined_debug_handles_set = set ()
551
+ for debug_handle in debug_handles :
552
+ combined_debug_handles_set .update (set (debug_handle ))
553
+ return tuple (sorted (combined_debug_handles_set ))
554
+
555
+
556
+ def _combine_overlapped_intermediate_outputs (
557
+ nodes : List [Tuple [Tuple [int , ...], Any ]]
558
+ ) -> Tuple [Tuple [int , ...], Any ]:
559
+ """Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
560
+ debug_handles = [debug_handle for debug_handle , _ in nodes ]
561
+ outputs = [output for _ , output in nodes ]
562
+ combined_debug_handle = _combine_debug_hanldes (debug_handles )
563
+ output = outputs [- 1 ] # Pick the last one
564
+ return combined_debug_handle , output
565
+
566
+
567
+ def _create_debug_handle_overlap_graph (
568
+ aot_intermediate_outputs : Dict [Tuple [int , ...], Any ],
569
+ runtime_intermediate_outputs : Dict [Tuple [int , ...], Any ],
570
+ ) -> Tuple [List [NodeData ], Dict [int , List [int ]]]:
571
+ """
572
+ Create a graph representing overlapping debug handles between AOT and runtime outputs.
573
+
574
+ Edges in the graph are represented as a dictionary where:
575
+ - The key is the index of a node in the nodes list.
576
+ - The value is a list of indices of nodes that have overlapping debug handles with the key node.
577
+
578
+ Returns:
579
+ - A tuple containing:
580
+ - A list of NodeData instances representing the nodes in the graph.
581
+ - A dictionary representing the edges, where each key-value pair indicates connected nodes due to overlapping debug handles.
582
+ """
583
+ nodes = []
584
+ for debug_handle , output in aot_intermediate_outputs .items ():
585
+ nodes .append (NodeData (NodeSource .AOT , debug_handle , output ))
586
+ for debug_handle , output in runtime_intermediate_outputs .items ():
587
+ nodes .append (NodeData (NodeSource .RUNTIME , debug_handle , output ))
588
+
589
+ edges = {i : [] for i in range (len (nodes ))}
590
+ for i in range (len (nodes )):
591
+ for j in range (i + 1 , len (nodes )):
592
+ node_i = nodes [i ]
593
+ node_j = nodes [j ]
594
+ # Only connect nodes from different sources(aot vs runtime) that overlap
595
+ if node_i .source != node_j .source and _debug_handles_have_overlap (
596
+ node_i .debug_handle , node_j .debug_handle
597
+ ):
598
+ edges [i ].append (j )
599
+ edges [j ].append (i )
600
+ return (nodes , edges )
601
+
602
+
603
+ def _find_connected_components (
604
+ nodes : List [NodeData ], edges : Dict [int , List [int ]]
605
+ ) -> List [List [int ]]:
606
+ """
607
+ Find groups of connected nodes in a graph using DFS.
608
+ Parameters:
609
+ - nodes: A list of nodes in the graph.
610
+ - edges: A dictionary where each key is a node index, and the value is a list
611
+ of indices of connected nodes.
612
+ Returns:
613
+ - A list of connected components, each represented as a list of node indices.
614
+ """
615
+ visited = [False ] * len (nodes )
616
+ connected_components = []
617
+
618
+ def dfs (node_id , component ):
619
+ visited [node_id ] = True
620
+ component .append (node_id )
621
+ # Iterate over all neighbors of the current node
622
+ for neighbor_node_id in edges [node_id ]:
623
+ # If a neighbor has not been visited yet, recursively visit it
624
+ if not visited [neighbor_node_id ]:
625
+ dfs (neighbor_node_id , component )
626
+
627
+ # Perform DFS on all nodes to find connected components
628
+ for i in range (len (nodes )):
629
+ # If a node has not been visited yet, start a new DFS from it
630
+ if not visited [i ]:
631
+ component = []
632
+ dfs (i , component )
633
+ # After visiting all reachable nodes, add the current component to the list
634
+ connected_components .append (component )
635
+ return connected_components
636
+
637
+
638
+ def map_runtime_aot_intermediate_outputs (
639
+ aot_intermediate_outputs : Dict [Tuple [int , ...], Any ],
640
+ runtime_intermediate_outputs : Dict [Tuple [int , ...], Any ],
641
+ ) -> Dict [Tuple [Tuple [int , ...], Any ], Tuple [Tuple [int , ...], Any ]]:
642
+ """
643
+ Map the runtime intermediate outputs to the AOT intermediate outputs
644
+ by finding overlapping debug handles and combining them into a single debug_handle
645
+
646
+ Returns:
647
+ Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]] - Mapping
648
+ from runtime intermediate output to AOT intermediate output
649
+ """
650
+ # Merge overlapping debug handles
651
+ merge_overlapping_debug_handles (aot_intermediate_outputs )
652
+ merge_overlapping_debug_handles (runtime_intermediate_outputs )
653
+
654
+ # Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
655
+ nodes , edges = _create_debug_handle_overlap_graph (
656
+ aot_intermediate_outputs , runtime_intermediate_outputs
657
+ )
658
+ # Find connected(between aot and runtime) components
659
+ connected_components = _find_connected_components (nodes , edges )
660
+
661
+ aot_runtime_mapping = {}
662
+ for comp in connected_components :
663
+ # Separate nodes into AOT and runtime lists based on their source,
664
+ # each list is combined into a single element and mapped to each other.
665
+ aot_list = [
666
+ (nodes [node_id ].debug_handle , nodes [node_id ].output )
667
+ for node_id in comp
668
+ if nodes [node_id ].source == NodeSource .AOT
669
+ ]
670
+ runtime_list = [
671
+ (nodes [node_id ].debug_handle , nodes [node_id ].output )
672
+ for node_id in comp
673
+ if nodes [node_id ].source == NodeSource .RUNTIME
674
+ ]
675
+
676
+ # Map only if both AOT and runtime data are present.
677
+ if len (aot_list ) != 0 and len (runtime_list ) != 0 :
678
+ # Combine aot debug handles into a single key
679
+ aot_combined_debug_handle , aot_output = (
680
+ _combine_overlapped_intermediate_outputs (aot_list )
681
+ )
682
+ # Combine runtime debug handles into a single key
683
+ runtime_combined_debug_handle , runtime_output = (
684
+ _combine_overlapped_intermediate_outputs (runtime_list )
685
+ )
686
+ # Create a mapping between runtime and aot
687
+ aot_runtime_mapping [(aot_combined_debug_handle , aot_output )] = (
688
+ runtime_combined_debug_handle ,
689
+ runtime_output ,
690
+ )
691
+
692
+ return aot_runtime_mapping
0 commit comments