diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py
index 24b9e030bf6..040d664a808 100644
--- a/devtools/inspector/_inspector_utils.py
+++ b/devtools/inspector/_inspector_utils.py
@@ -625,23 +625,28 @@ def _debug_handles_have_overlap(
     return len(aot_set.intersection(runtime_set)) > 0
 
 
-def _combine_debug_handles(debug_handles: List[DebugHandle]) -> DebugHandle:
-    """Combine multiple debug handles into one debug handle"""
-    combined_debug_handles_set = set()
-    for debug_handle in debug_handles:
-        combined_debug_handles_set.update(set(debug_handle))
-    return tuple(sorted(combined_debug_handles_set))
+def _combine_aot_overlapped_intermediate_outputs(
+    aot_nodes: List[Tuple[DebugHandle, Any]], runtime_node: Tuple[DebugHandle, Any]
+) -> Tuple[DebugHandle, Any]:
+    """
+    Ensure the AOT combined debug_handles are the same as the runtime debug_handles (order ignored),
+    then pick the last intermediate output based on the runtime debug_handles
+    """
+    # Map AOT single element debug_handles to outputs
+    aot_map = dict(aot_nodes)
+    runtime_debug_handle, _ = runtime_node
 
+    # Combine all AOT debug_handles into a list
+    aot_combined_debug_handle = [t[0] for t in aot_map.keys()]
 
-def _combine_overlapped_intermediate_outputs(
-    nodes: List[Tuple[DebugHandle, Any]]
-) -> Tuple[DebugHandle, Any]:
-    """Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
-    debug_handles = [debug_handle for debug_handle, _ in nodes]
-    outputs = [output for _, output in nodes]
-    combined_debug_handle = _combine_debug_handles(debug_handles)
-    output = outputs[-1]  # Pick the last one
-    return combined_debug_handle, output
+    if set(aot_combined_debug_handle) != set(runtime_debug_handle):
+        # AOT combined debug_handle and runtime debug_handle do not match.
+        return (-1,), None
+
+    # Pick the last intermediate output
+    last_int = runtime_debug_handle[-1]
+    key = (last_int,)
+    return runtime_debug_handle, aot_map[key]
 
 
 def _create_debug_handle_overlap_graph(
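To make the matching rule of the new helper easier to follow, here is a minimal, self-contained sketch of the same logic with invented handle and output values (the `DebugHandle` alias is assumed to be a tuple of ints, as used elsewhere in `_inspector_utils.py`): the single-element AOT handles must cover exactly the runtime handle, and the AOT output keyed by the last element of the runtime handle is the one that gets paired with it.

```python
from typing import Any, Dict, List, Tuple

DebugHandle = Tuple[int, ...]  # assumed alias, mirroring _inspector_utils.py


def combine_aot_overlapped_sketch(
    aot_nodes: List[Tuple[DebugHandle, Any]], runtime_node: Tuple[DebugHandle, Any]
) -> Tuple[DebugHandle, Any]:
    # Map the single-element AOT debug handles to their outputs.
    aot_map: Dict[DebugHandle, Any] = dict(aot_nodes)
    runtime_debug_handle, _ = runtime_node
    # The flattened AOT handles must equal the runtime handle as a set ...
    if {handle[0] for handle in aot_map} != set(runtime_debug_handle):
        return (-1,), None  # ... otherwise the pair is treated as a mismatch.
    # ... and the AOT output keyed by the runtime handle's last element is picked.
    return runtime_debug_handle, aot_map[(runtime_debug_handle[-1],)]


# (2, 1) is covered by the AOT handles (1,) and (2,); its last element is 1,
# so the output recorded under (1,) is returned.
print(combine_aot_overlapped_sketch([((1,), "out_1"), ((2,), "out_2")], ((2, 1), "rt")))
# ((2, 1), 'out_1')
print(combine_aot_overlapped_sketch([((1,), "out_1")], ((2, 3), "rt")))
# ((-1,), None)
```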
@@ -751,38 +756,52 @@ def map_runtime_aot_intermediate_outputs(
         # Map only if both AOT and runtime data are present.
         if len(aot_list) != 0 and len(runtime_list) != 0:
+            # The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
+            # Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_handles.
+            # As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
+            assert len(runtime_list) == 1
+
+            runtime_debug_handle, runtime_intermediate_output = runtime_list[0]
+
             # Combine aot debug handles into a single key
             aot_combined_debug_handle, aot_intermediate_output = (
-                _combine_overlapped_intermediate_outputs(aot_list)
+                _combine_aot_overlapped_intermediate_outputs(aot_list, runtime_list[0])
             )
-            # Combine runtime debug handles into a single key
-            runtime_combined_debug_handle, runtime_intermediate_output = (
-                _combine_overlapped_intermediate_outputs(runtime_list)
-            )
-            # List can't be used as a key, so convert to tuple
-            if isinstance(aot_intermediate_output, list):
+
+            if aot_combined_debug_handle == (-1,):
+                # Skip this mapping if the aot combined debug handle and runtime debug handle do not exactly match.
+                continue
+
+            if isinstance(aot_intermediate_output, Sequence):
+                if not isinstance(runtime_intermediate_output, Sequence):
+                    raise TypeError(
+                        "runtime intermediate output should be a sequence when aot intermediate output is a sequence"
+                    )
+                last_element = runtime_intermediate_output[-1]
+                if isinstance(last_element, list) and all(
+                    isinstance(t, torch.Tensor) for t in last_element
+                ):
+                    # If the last element is a list of tensors (delegate case)
+                    runtime_intermediate_output = last_element
+                elif isinstance(last_element, torch.Tensor):
+                    # If the last element is a tensor (non-delegate case)
+                    pass
+                else:
+                    raise ValueError(
+                        "The last element of runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence"
+                    )
+                # List can't be used as a key, so convert to tuple
                 aot_intermediate_output = tuple(aot_intermediate_output)
-            # runtime follow the same format as aot, so it's safe to convert to tuple
-            if isinstance(runtime_intermediate_output, list):
                 runtime_intermediate_output = tuple(runtime_intermediate_output)
 
-            # Currently, runtime_intermediate_output logs all delegate call arguments.
-            # Process here to extract only the outputs.
-            if isinstance(aot_intermediate_output, tuple):
-                # If both are sequences, slice runtime_intermediate_output to match the length of aot_intermediate_output
-                if isinstance(runtime_intermediate_output, tuple):
-                    runtime_intermediate_output = runtime_intermediate_output[
-                        -len(aot_intermediate_output) :
-                    ]
-            # If aot_intermediate_output is not a sequence but runtime_intermediate_output is, get the last element
-            elif isinstance(runtime_intermediate_output, tuple):
+            elif isinstance(runtime_intermediate_output, Sequence):
+                # Delegate runtime call and AOT intermediate output is not a sequence; just take the last element from the runtime list
                 runtime_intermediate_output = runtime_intermediate_output[-1]
 
             # Create a mapping between runtime and aot
             aot_runtime_mapping[
                 (aot_combined_debug_handle, aot_intermediate_output)
             ] = (
-                runtime_combined_debug_handle,
+                runtime_debug_handle,
                 runtime_intermediate_output,
             )
 
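The `Sequence` branches above are easiest to see on a toy input: the runtime side logs all delegate call arguments, and the new code inspects the last logged element, keeping it when it is a list of output tensors (delegate case), leaving the sequence untouched when it is a plain tensor, and taking the last element when the AOT output is not a sequence at all. The sketch below mirrors just that extraction step outside the full function, with made-up tensor values:

```python
from collections.abc import Sequence
from typing import Any

import torch


def extract_runtime_outputs_sketch(aot_output: Any, runtime_output: Any) -> Any:
    # Mirrors the branch added above, outside of map_runtime_aot_intermediate_outputs.
    if isinstance(aot_output, Sequence):
        last_element = runtime_output[-1]
        if isinstance(last_element, list) and all(
            isinstance(t, torch.Tensor) for t in last_element
        ):
            return last_element  # delegate case: the last logged element is the list of outputs
        if isinstance(last_element, torch.Tensor):
            return runtime_output  # non-delegate case: keep the logged sequence as is
        raise ValueError("last element must be a tensor or a list of tensors")
    if isinstance(runtime_output, Sequence):
        return runtime_output[-1]  # AOT output is a single value: take the last logged element
    return runtime_output


# A delegate call logged as (arg, arg, [out_0, out_1]) is reduced to its outputs.
logged = [torch.tensor([1, 2, 3]), torch.tensor([9]), [torch.tensor([4, 5]), torch.tensor([6])]]
print(extract_runtime_outputs_sketch([torch.tensor([4, 5]), torch.tensor([6])], logged))
# [tensor([4, 5]), tensor([6])]
```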
@@ -890,7 +909,9 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
     if is_a_sequence and is_b_sequence:
         # Ensure both sequences have the same length
         if len(a) != len(b):
-            raise ValueError("Sequences must have the same length for comparison.")
+            raise ValueError(
+                f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
+            )
 
         # Compare each element in the sequences and return the list of results
         return [comparator.compare(x, y) for x, y in zip(a, b)]
@@ -899,7 +920,9 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
         return [comparator.compare(a, b)]
     else:
         # Raise an error if one is a sequence and the other is not
-        raise ValueError("Both inputs must be sequences or both must be non-sequences.")
+        raise ValueError(
+            f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences."
+        )
 
 
 def propagate_back_debug_handle(
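The two `compare_intermediate_outputs` hunks above only enrich the error messages with the offending values; the comparison contract itself is unchanged. A minimal sketch of that contract, using a hypothetical stand-in comparator (any object exposing a `compare(x, y)` method), looks like this:

```python
from collections.abc import Sequence
from typing import Any, List


class AbsDiffComparator:
    # Hypothetical stand-in: anything exposing compare(x, y) -> float fits the contract.
    def compare(self, x: float, y: float) -> float:
        return abs(x - y)


def compare_intermediate_outputs_sketch(a: Any, b: Any, comparator: Any) -> List[float]:
    is_a_sequence, is_b_sequence = isinstance(a, Sequence), isinstance(b, Sequence)
    if is_a_sequence and is_b_sequence:
        if len(a) != len(b):
            raise ValueError(
                f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison."
            )
        return [comparator.compare(x, y) for x, y in zip(a, b)]  # element-wise results
    if not is_a_sequence and not is_b_sequence:
        return [comparator.compare(a, b)]  # single pair, still returned as a list
    raise ValueError(
        f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences."
    )


print(compare_intermediate_outputs_sketch([1.0, 2.0], [1.5, 1.0], AbsDiffComparator()))  # [0.5, 1.0]
print(compare_intermediate_outputs_sketch(3.0, 2.0, AbsDiffComparator()))  # [1.0]
```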
diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py
index 29f6f7ce4a8..2d5ff242e22 100644
--- a/devtools/inspector/tests/inspector_utils_test.py
+++ b/devtools/inspector/tests/inspector_utils_test.py
@@ -302,23 +302,9 @@ def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
         }
         self.assertEqual(actual, expected)
 
-    def test_map_runtime_aot_intermediate_outputs_exact_match(self):
-        # Exact match between aot and runtime debug_handles
-        aot_intermediate_outputs = {(0, 1): 100, (2, 3): 200, (4, 5): 300}
-        runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
-        actual = map_runtime_aot_intermediate_outputs(
-            aot_intermediate_outputs, runtime_intermediate_outputs
-        )
-        expected = {
-            ((0, 1), 100): ((0, 1), 150),
-            ((2, 3), 200): ((2, 3), 200),
-            ((4, 5), 300): ((4, 5), 300),
-        }
-        self.assertEqual(actual, expected)
-
     def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
         # No overlaps between aot and runtime debug_handles
-        aot_intermediate_outputs = {(0, 1): 100, (4, 5): 300}
+        aot_intermediate_outputs = {(0,): 100, (4,): 300}
         runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
         actual = map_runtime_aot_intermediate_outputs(
             aot_intermediate_outputs, runtime_intermediate_outputs
@@ -326,43 +312,36 @@ def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
         expected = {}
         self.assertEqual(actual, expected)
 
-    def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
-        # Multiple aot debug_handles map to one runtime debug_handle
-        aot_intermediate_outputs = {(0, 1, 2): 100, (3, 4): 300}
-        runtime_intermediate_outputs = {(1, 2, 3): 250, (8, 9): 300}
-        actual = map_runtime_aot_intermediate_outputs(
-            aot_intermediate_outputs, runtime_intermediate_outputs
-        )
-        expected = {((0, 1, 2, 3, 4), 300): ((1, 2, 3), 250)}
-        self.assertEqual(actual, expected)
-
-    def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime(self):
-        # One aot debug_handle map to multiple runtime debug_handles
-        aot_intermediate_outputs = {(0, 1, 2, 3, 4): 100, (8, 9): 300}
-        runtime_intermediate_outputs = {(0, 1): 150, (2, 3): 200, (4, 5): 300}
+    def test_map_runtime_aot_intermediate_outputs_partial_match(self):
+        # Partial match between aot and runtime debug_handles will return empty
+        aot_intermediate_outputs = {(2,): 100, (9,): 300}
+        runtime_intermediate_outputs = {(2, 3): 200, (8, 9): 300}
         actual = map_runtime_aot_intermediate_outputs(
             aot_intermediate_outputs, runtime_intermediate_outputs
         )
-        expected = {((0, 1, 2, 3, 4), 100): ((0, 1, 2, 3, 4, 5), 300)}
+        expected = {}
         self.assertEqual(actual, expected)
 
-    def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
-        # Complex chain (N-to-N mapping)
-        aot_intermediate_outputs = {(1, 2): 100, (3, 4): 200, (5, 6): 300}
-        runtime_intermediate_outputs = {(2, 3): 150, (4, 5): 250, (6, 7): 350}
+    def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):
+        # Multiple aot debug_handles map to one runtime debug_handle
+        aot_intermediate_outputs = {(0,): 100, (1,): 200, (2,): 300, (3,): 400}
+        runtime_intermediate_outputs = {(2, 3, 1): 250, (8, 9): 300}
         actual = map_runtime_aot_intermediate_outputs(
             aot_intermediate_outputs, runtime_intermediate_outputs
         )
-        expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
+        expected = {((2, 3, 1), 200): ((2, 3, 1), 250)}
         self.assertEqual(actual, expected)
 
     def test_map_runtime_aot_intermediate_outputs_delegated(self):
         # Currently, runtime_intermediate_output logs all delegate call arguments
         # Test that the map function correctly extracted out the delegated outputs
         aot_intermediate_outputs = {
-            (1, 2): torch.tensor([4, 5]),
-            (3, 4): torch.tensor([10, 11, 12]),
-            (5, 6): torch.tensor([13, 14, 15, 16, 17]),
+            (1,): torch.tensor([4, 1]),
+            (2,): torch.tensor([4, 5]),
+            (3,): torch.tensor([10, 10, 13]),
+            (4,): torch.tensor([10, 11, 12]),
+            (5,): torch.tensor([13, 14, 15, 16, 21]),
+            (6,): torch.tensor([13, 14, 15, 16, 17]),
         }
         runtime_intermediate_outputs = {
             (1, 2): [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],