Skip to content

Commit 5c364d9

Browse files
authored
Updated AOT debug_handle to operator names mapping
Differential Revision: D78118798 Pull Request resolved: #12366
1 parent 2e1bcf9 commit 5c364d9

File tree

5 files changed

+71
-44
lines changed

5 files changed

+71
-44
lines changed

devtools/inspector/_inspector.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,7 @@ def _consume_etrecord(self) -> None:
11591159

11601160
def _get_aot_intermediate_outputs_and_op_names(
11611161
self,
1162-
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, str]]:
1162+
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
11631163
"""
11641164
Capture intermediate outputs only if _representative_inputs are provided
11651165
when using bundled program to create the etrecord
@@ -1180,13 +1180,13 @@ def _get_aot_intermediate_outputs_and_op_names(
11801180
# TODO: Make it more extensible to further merge overlapping debug handles
11811181
def _get_runtime_intermediate_outputs_and_op_names(
11821182
self,
1183-
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, str]]:
1183+
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
11841184
"""
11851185
Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings)
11861186
from the event blocks, along with the corresponding debug handles and op names mapping.
11871187
"""
11881188
debug_handle_to_output = {}
1189-
debug_handle_to_op_name = {}
1189+
debug_handle_to_op_names = {}
11901190
for event_block in self.event_blocks:
11911191
for event in event_block.events:
11921192
# Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax
@@ -1208,12 +1208,13 @@ def _get_runtime_intermediate_outputs_and_op_names(
12081208
event._instruction_id,
12091209
event.debug_data,
12101210
)
1211-
debug_handle_to_op_name[debug_handle] = event.name
1211+
# TODO: One debug handle can be associated with multiple op names
1212+
debug_handle_to_op_names[debug_handle] = [event.name]
12121213

12131214
merge_runtime_overlapping_debug_handles(debug_handle_to_output)
12141215
return {
12151216
k: v[1] for k, v in debug_handle_to_output.items()
1216-
}, debug_handle_to_op_name
1217+
}, debug_handle_to_op_names
12171218

12181219
def to_dataframe(
12191220
self,
@@ -1385,15 +1386,15 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
13851386
pd.DataFrame: A DataFrame listing corresponding operator outputs from
13861387
both stages and their computed numerical gaps.
13871388
"""
1388-
aot_intermediate_outputs, aot_debug_handle_to_op_name = (
1389+
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
13891390
self._get_aot_intermediate_outputs_and_op_names()
13901391
)
1391-
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_name) == 0:
1392+
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
13921393
raise ValueError(
13931394
"Missing etrecord or missing representative inputs within etrecord, both of which are required for calculating numerical gap"
13941395
)
13951396
# The runtime_op_names will be used later to map runtime debug_handle to op_name
1396-
runtime_intermediate_outputs, runtime_debug_handle_to_op_name = (
1397+
runtime_intermediate_outputs, runtime_debug_handle_to_op_names = (
13971398
self._get_runtime_intermediate_outputs_and_op_names()
13981399
)
13991400
mapping = map_runtime_aot_intermediate_outputs(
@@ -1419,11 +1420,11 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
14191420
rows.append(
14201421
{
14211422
"aot_ops": find_op_names(
1422-
aot_debug_handle, aot_debug_handle_to_op_name
1423+
aot_debug_handle, aot_debug_handle_to_op_names
14231424
),
14241425
"aot_intermediate_output": aot_intermediate_output,
14251426
"runtime_ops": find_op_names(
1426-
runtime_debug_handle, runtime_debug_handle_to_op_name
1427+
runtime_debug_handle, runtime_debug_handle_to_op_names
14271428
),
14281429
"runtime_intermediate_output": runtime_intermediate_output,
14291430
"gap": compare_intermediate_outputs(

devtools/inspector/_inspector_utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -823,13 +823,13 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
823823

824824
def get_aot_debug_handle_to_op_name_mapping(
825825
graph_module: torch.fx.GraphModule,
826-
) -> Dict[DebugHandle, str]:
826+
) -> Dict[DebugHandle, List[str]]:
827827
"""
828828
Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module.
829829
Parameters:
830830
graph_module (torch.fx.GraphModule): The graph module to get the mapping from.
831831
Returns:
832-
Dict[DebugHandle, str]: A dictionary mapping debug handles to operator names.
832+
Dict[DebugHandle, List[str]]: A dictionary mapping debug handles to operator names.
833833
"""
834834
node_filters = [
835835
NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
@@ -845,26 +845,29 @@ def get_aot_debug_handle_to_op_name_mapping(
845845
if isinstance(debug_handle, int)
846846
else tuple(debug_handle)
847847
)
848-
debug_handle_to_op_name[key] = node.name
848+
if key in debug_handle_to_op_name:
849+
debug_handle_to_op_name[key].append(node.name)
850+
else:
851+
debug_handle_to_op_name[key] = [node.name]
849852
return debug_handle_to_op_name
850853

851854

852855
def find_op_names(
853856
target_debug_handle: DebugHandle,
854-
debug_handle_to_op_name: Dict[DebugHandle, str],
857+
debug_handle_to_op_names: Dict[DebugHandle, List[str]],
855858
) -> List[str]:
856859
"""
857860
Record the operator names only if their debug handles are part of the target debug handle.
858-
The debug handles in `debug_handle_to_op_name` have undergone merging and remain unchanged,
861+
The debug handles in `debug_handle_to_op_names` have undergone merging and remain unchanged,
859862
and this function identifies operations corresponding to these transformed handles.
860863
"""
861864
dh_set = set(target_debug_handle)
862865
result = []
863866

864-
for key_tuple, op_name in debug_handle_to_op_name.items():
867+
for key_tuple, op_name in debug_handle_to_op_names.items():
865868
# Check if key is a subset of the target_debug_handle
866869
if set(key_tuple).issubset(dh_set):
867-
result.append(op_name)
870+
result.extend(op_name)
868871

869872
return result
870873

devtools/inspector/tests/inspector_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
TimeScale,
4545
)
4646
from executorch.devtools.inspector.tests.inspector_test_utils import (
47-
check_if_debug_handle_to_op_name_match,
47+
check_if_debug_handle_to_op_names_match,
4848
check_if_final_outputs_match,
4949
model_registry,
5050
)
@@ -522,17 +522,18 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self):
522522
_representative_inputs=aten_model.example_inputs[0],
523523
)
524524
inspector_instance._etrecord = etrecord
525-
aot_intermediate_outputs, aot_debug_handle_to_op_name = (
525+
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
526526
inspector_instance._get_aot_intermediate_outputs_and_op_names()
527527
)
528528
self.assertTrue(
529529
check_if_final_outputs_match(
530530
"ConvLinearModel", aot_intermediate_outputs
531531
)
532532
)
533+
533534
self.assertTrue(
534-
check_if_debug_handle_to_op_name_match(
535-
"ConvLinearModel", aot_debug_handle_to_op_name
535+
check_if_debug_handle_to_op_names_match(
536+
"ConvLinearModel", aot_debug_handle_to_op_names
536537
)
537538
)
538539

@@ -584,14 +585,14 @@ def test_get_runtime_intermediate_outputs_and_op_names(self):
584585
self.assertTrue(
585586
torch.allclose(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
586587
)
587-
self.assertEqual(op_names[(4,)], "op_3")
588+
self.assertEqual(op_names[(4,)], ["op_3"])
588589

589590
# Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
590591
for key in range(5, 9):
591592
self.assertIn((key,), runtime_outputs)
592593
self.assertIn((key,), op_names)
593594
self.assertEqual(runtime_outputs[(key,)][0].size(0), RAW_DATA_SIZE)
594-
self.assertEqual(op_names[(key,)], f"op_{key-1}")
595+
self.assertEqual(op_names[(key,)], [f"op_{key-1}"])
595596

596597
def test_calculate_numeric_gap(self):
597598
# Create a context manager to patch functions called by Inspector.__init__

devtools/inspector/tests/inspector_test_utils.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,22 @@ def get_expected_intermediate_outputs():
7676
}
7777

7878
@staticmethod
79-
def get_expected_debug_handle_to_op_name():
79+
def get_expected_debug_handle_to_op_names():
8080
"""
81-
Returns the expected debug handle and op name mapping for this model for the given input.
81+
Returns the expected debug handle and op names mapping for this model for the given input.
8282
"""
8383
return {
84-
(1,): "aten_convolution_default",
85-
(2,): "aten_view_copy_default",
86-
(3,): "aten_addmm_default",
87-
(4,): "aten_add_tensor",
88-
(5,): "aten_sub_tensor",
89-
(6,): "aten_mul_tensor",
90-
(7,): "aten_add_tensor_1",
91-
(8,): "aten_div_tensor",
92-
(9,): "aten_relu_default",
93-
(10,): "aten_sigmoid_default",
94-
(11,): "aten_split_with_sizes_copy_default",
84+
(1,): ["aten_convolution_default"],
85+
(2,): ["aten_view_copy_default"],
86+
(3,): ["aten_permute_copy_default", "aten_addmm_default"],
87+
(4,): ["aten_add_tensor"],
88+
(5,): ["aten_sub_tensor"],
89+
(6,): ["aten_mul_tensor"],
90+
(7,): ["aten_add_tensor_1"],
91+
(8,): ["aten_div_tensor"],
92+
(9,): ["aten_relu_default"],
93+
(10,): ["aten_sigmoid_default"],
94+
(11,): ["aten_split_with_sizes_copy_default"],
9595
}
9696

9797

@@ -129,14 +129,14 @@ def check_if_final_outputs_match(model_name, actual_outputs_with_handles):
129129
return True
130130

131131

132-
def check_if_debug_handle_to_op_name_match(model_name, actual_debug_handle_to_op_name):
132+
def check_if_debug_handle_to_op_names_match(model_name, actual_debug_handle_to_op_name):
133133
"""
134134
Checks if the actual op names match the expected op names for the specified model.
135135
Returns True if all match, otherwise returns False.
136136
"""
137137
model_instance = model_registry[model_name]
138138
expected_debug_handle_to_op_name = (
139-
model_instance.get_expected_debug_handle_to_op_name()
139+
model_instance.get_expected_debug_handle_to_op_names()
140140
)
141141
if len(actual_debug_handle_to_op_name) != len(expected_debug_handle_to_op_name):
142142
return False

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self):
455455
)
456456
node.meta["debug_handle"] = 1
457457
debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
458-
expected_result = {(1,): "op1"}
458+
expected_result = {(1,): ["op1"]}
459459
self.assertEqual(debug_handle_to_op_name, expected_result)
460460

461461
def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self):
@@ -474,8 +474,8 @@ def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self):
474474
(
475475
1,
476476
2,
477-
): "op1",
478-
(3,): "op2",
477+
): ["op1"],
478+
(3,): ["op2"],
479479
}
480480
self.assertEqual(debug_handle_to_op_name, expected_result)
481481

@@ -555,21 +555,43 @@ def test_node_op_type_mismatch(self):
555555

556556
def test_find_op_names_empty_debug_handle(self):
557557
debug_handle = ()
558-
debug_handle_to_op_name = {(1, 2): "op1", (3, 4): "op2"}
558+
debug_handle_to_op_name = {(1, 2): ["op1"], (3, 4): ["op2"]}
559559
self.assertEqual(find_op_names(debug_handle, debug_handle_to_op_name), [])
560560

561561
def test_find_op_names_no_matching_handles(self):
562562
debug_handle = (1, 2)
563-
debug_handle_to_op_name = {(3, 4): "op1", (5, 6): "op2"}
563+
debug_handle_to_op_name = {(3, 4): ["op1"], (5, 6): ["op2"]}
564564
self.assertEqual(find_op_names(debug_handle, debug_handle_to_op_name), [])
565565

566566
def test_find_op_names_matching_handles(self):
567567
debug_handle = (1, 2, 3)
568-
debug_handle_to_op_name = {(1, 2): "op1", (2, 3): "op2", (4, 5, 6): "op3"}
568+
debug_handle_to_op_name = {(1, 2): ["op1"], (2, 3): ["op2"], (4, 5, 6): ["op3"]}
569569
self.assertEqual(
570570
find_op_names(debug_handle, debug_handle_to_op_name), ["op1", "op2"]
571571
)
572572

573+
def test_find_op_names_multiple_ops_single_handle(self):
574+
"""Test when a single debug handle maps to multiple operator names"""
575+
debug_handle = (1, 2, 3)
576+
debug_handle_to_op_name = {(1, 2): ["op1", "op2", "op3"], (4, 5): ["op4"]}
577+
self.assertEqual(
578+
find_op_names(debug_handle, debug_handle_to_op_name), ["op1", "op2", "op3"]
579+
)
580+
581+
def test_find_op_names_mixed_single_and_multiple_ops(self):
582+
"""Test mix of handles with single and multiple operator names"""
583+
debug_handle = (1, 2, 3, 4, 5)
584+
debug_handle_to_op_name = {
585+
(1, 2): ["op1"],
586+
(3,): ["op2", "op3"],
587+
(4,): ["op4"],
588+
(5,): ["op5", "op6", "op7"], # Multiple ops
589+
}
590+
self.assertEqual(
591+
find_op_names(debug_handle, debug_handle_to_op_name),
592+
["op1", "op2", "op3", "op4", "op5", "op6", "op7"],
593+
)
594+
573595
def test_compare_intermediate_outputs_sequences(self):
574596
a = [1.0, 2.0, 3.0]
575597
b = [1.0, 2.5, 3.5]

0 commit comments

Comments
 (0)