Skip to content

Commit 47d1592

Browse files
authored
Mapping AOT debug_handles to op names
Differential Revision: D77244175 Pull Request resolved: #11930
1 parent 6e706f2 commit 47d1592

File tree

5 files changed

+166
-23
lines changed

5 files changed

+166
-23
lines changed

devtools/inspector/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ python_library(
5656
"_intermediate_output_capturer.py",
5757
],
5858
deps = [
59+
"//executorch/devtools/inspector:inspector_utils",
5960
],
6061
)
6162

devtools/inspector/_inspector.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
FORWARD,
5353
gen_etdump_object,
5454
gen_graphs_from_etrecord,
55+
get_aot_debug_handle_to_op_name_mapping,
5556
inflate_runtime_output,
5657
is_debug_output,
5758
is_inference_output_equal,
@@ -1084,6 +1085,7 @@ def __init__(
10841085
self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
10851086
self._enable_module_hierarchy = enable_module_hierarchy
10861087
self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None
1088+
self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None
10871089
self._consume_etrecord()
10881090

10891091
def _consume_etrecord(self) -> None:
@@ -1150,6 +1152,9 @@ def _consume_etrecord(self) -> None:
11501152
return
11511153
export_program = self._etrecord.edge_dialect_program
11521154
graph_module = export_program.module()
1155+
self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping(
1156+
graph_module
1157+
)
11531158
capturer = IntermediateOutputCapturer(graph_module)
11541159
self._aot_intermediate_outputs = capturer.run_and_capture(
11551160
self._etrecord._representative_inputs

devtools/inspector/_inspector_utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,28 @@ class NodeData:
9393
output: Any
9494

9595

96+
class NodeFilter:
    """
    A class used to filter torch.fx nodes based on extensible criteria.

    Attributes:
        metadata_key (str): The key to look for in the node's metadata.
        op_type (str): The operation code to match (e.g. "call_function").
        exclude_ops (List[str]): Substrings; any node whose name contains one
            of them is excluded from matching.
    """

    def __init__(
        self,
        metadata_key: str,
        op_type: str,
        exclude_ops: Optional[List[str]] = None,
    ):
        self.metadata_key = metadata_key
        self.op_type = op_type
        # Bug fix: the original stored None directly, so matches() crashed with
        # TypeError when exclude_ops was omitted. Normalize None to an empty list.
        self.exclude_ops: List[str] = exclude_ops if exclude_ops is not None else []

    def matches(self, node: torch.fx.Node) -> bool:
        """Return True iff the node carries the metadata key, has the expected
        op type, and its name contains none of the excluded substrings."""
        return (
            node.meta.get(self.metadata_key) is not None
            and node.op == self.op_type
            and all(
                exclude_name not in node.name for exclude_name in self.exclude_ops
            )
        )
116+
117+
96118
def calculate_time_scale_factor(
97119
source_time_scale: TimeScale, target_time_scale: TimeScale
98120
) -> float:
@@ -734,3 +756,31 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
734756
if torch.isnan(input_tensor).any():
735757
input_tensor = torch.nan_to_num(input_tensor)
736758
return input_tensor
759+
760+
761+
def get_aot_debug_handle_to_op_name_mapping(
    graph_module: torch.fx.GraphModule,
) -> Dict[Tuple[int, ...], str]:
    """
    Get a mapping from debug handle to operator name from the ETRecord
    edge_dialect_program's graph module.

    Parameters:
        graph_module (torch.fx.GraphModule): The graph module to get the
            mapping from.

    Returns:
        Dict[Tuple[int, ...], str]: A dictionary mapping debug handles
            (always normalized to tuples) to operator names.
    """
    # Keep only call_function nodes that carry a debug_handle and whose name
    # does not contain "getitem" (getitem nodes are tuple unpacking, not ops).
    node_filters = [
        NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
    ]

    debug_handle_to_op_name: Dict[Tuple[int, ...], str] = {}
    for node in graph_module.graph.nodes:
        # Renamed loop variable from `filter` (shadowed the builtin) to `node_filter`.
        if not all(node_filter.matches(node) for node_filter in node_filters):
            continue
        debug_handle = node.meta["debug_handle"]
        # Dict keys must be hashable tuples; wrap bare int handles.
        key = (
            (debug_handle,)
            if isinstance(debug_handle, int)
            else tuple(debug_handle)
        )
        debug_handle_to_op_name[key] = node.name
    return debug_handle_to_op_name

devtools/inspector/_intermediate_output_capturer.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,14 @@
77
# pyre-unsafe
88

99

10-
from typing import Any, Dict, List, Tuple
10+
from typing import Any, Dict, Tuple
1111

1212
import torch
13+
from executorch.devtools.inspector._inspector_utils import NodeFilter
1314
from torch.fx import GraphModule
1415
from torch.fx.interpreter import Interpreter
1516

1617

17-
class NodeFilter:
18-
"""
19-
A class used to filter nodes based on extensible criteria.
20-
Attributes:
21-
metadata_key (str): The key to look for in the node's metadata.
22-
op_type (str): The operation code to match.
23-
exclude_ops (List[str]): A list of operations to exclude from the filter.
24-
"""
25-
26-
def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None):
27-
self.metadata_key = metadata_key
28-
self.op_type = op_type
29-
self.exclude_ops = exclude_ops
30-
31-
def matches(self, node: torch.fx.Node) -> bool:
32-
return (
33-
node.meta.get(self.metadata_key) is not None
34-
and node.op == self.op_type
35-
and all(exclude_name not in node.name for exclude_name in self.exclude_ops)
36-
)
37-
38-
3918
class IntermediateOutputCapturer(Interpreter):
4019
"""
4120
A class that captures intermediate outputs from a PyTorch graph module.

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@
3434
EDGE_DIALECT_GRAPH_KEY,
3535
find_populated_event,
3636
gen_graphs_from_etrecord,
37+
get_aot_debug_handle_to_op_name_mapping,
3738
is_inference_output_equal,
3839
map_runtime_aot_intermediate_outputs,
3940
merge_overlapping_debug_handles,
41+
NodeFilter,
4042
TimeScale,
4143
)
4244

@@ -364,6 +366,112 @@ class X:
364366
msg = str(cm.exception)
365367
self.assertIn("Cannot convert value of type", msg)
366368

369+
def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self):
    # Build a one-node graph whose node carries a bare int debug handle.
    gm = torch.fx.GraphModule({}, torch.fx.Graph())
    mul_node = gm.graph.create_node(
        "call_function", target=torch.mul, args=(), kwargs={}, name="op1"
    )
    mul_node.meta["debug_handle"] = 1
    # An int handle must be normalized to a single-element tuple key.
    self.assertEqual(
        get_aot_debug_handle_to_op_name_mapping(gm), {(1,): "op1"}
    )
379+
380+
def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self):
    # Two nodes: one with a tuple handle, one with a plain int handle.
    gm = torch.fx.GraphModule({}, torch.fx.Graph())
    first = gm.graph.create_node(
        "call_function", target=torch.mul, args=(), kwargs={}, name="op1"
    )
    first.meta["debug_handle"] = (1, 2)
    second = gm.graph.create_node(
        "call_function", target=torch.mul, args=(), kwargs={}, name="op2"
    )
    second.meta["debug_handle"] = 3
    # Tuple handles pass through; int handles are wrapped into tuples.
    expected = {(1, 2): "op1", (3,): "op2"}
    self.assertEqual(get_aot_debug_handle_to_op_name_mapping(gm), expected)
400+
401+
def test_get_aot_debug_handle_to_op_name_mapping_no_debug_handles(self):
    # An empty graph must yield an empty mapping.
    empty_gm = torch.fx.GraphModule({}, torch.fx.Graph())
    self.assertEqual(get_aot_debug_handle_to_op_name_mapping(empty_gm), {})
407+
408+
def test_node_filter_match(self):
    node_filter = NodeFilter(
        "debug_handle", "call_function", exclude_ops=["getitem"]
    )
    # A call_function node with a debug_handle and a non-excluded name
    # satisfies every criterion of the filter.
    candidate = torch.fx.Node(
        graph=torch.fx.Graph(),
        name="mock_node",
        op="call_function",
        target=torch.nn.functional.relu,
        args=(),
        kwargs={},
    )
    candidate.meta["debug_handle"] = (1, 2)
    self.assertTrue(node_filter.matches(candidate))
425+
426+
def test_node_filter_key_mismatch(self):
    node_filter = NodeFilter(
        "debug_handle", "call_function", exclude_ops=["getitem"]
    )
    # Node has the right op type but lacks the "debug_handle" metadata key,
    # so the filter must reject it.
    no_handle_node = torch.fx.Node(
        graph=torch.fx.Graph(),
        name="mock_node_metadata_key_mismatch",
        op="call_function",
        target=torch.nn.functional.relu,
        args=(),
        kwargs={},
    )
    self.assertFalse(node_filter.matches(no_handle_node))
440+
441+
def test_node_filter_ops_mismatch(self):
    node_filter = NodeFilter(
        "debug_handle", "call_function", exclude_ops=["getitem"]
    )
    # The node's name contains the excluded substring "getitem",
    # so the filter must reject it despite op type and metadata matching.
    excluded_node = torch.fx.Node(
        graph=torch.fx.Graph(),
        name="getitem",
        op="call_function",
        target=torch.nn.functional.relu,
        args=(),
        kwargs={},
    )
    excluded_node.meta["debug_handle"] = (1, 2)
    self.assertFalse(node_filter.matches(excluded_node))
457+
458+
def test_node_op_type_mismatch(self):
    node_filter = NodeFilter(
        "debug_handle", "call_function", exclude_ops=["getitem"]
    )
    # A get_attr node must not satisfy a call_function filter,
    # even with a debug_handle present.
    attr_node = torch.fx.Node(
        graph=torch.fx.Graph(),
        name="mock_node_op_type_mismatch",
        op="get_attr",
        target="torch.nn.functional.relu",
        args=(),
        kwargs={},
    )
    attr_node.meta["debug_handle"] = (1, 2)
    self.assertFalse(node_filter.matches(attr_node))
474+
367475

368476
def gen_mock_operator_graph_with_expected_map() -> (
369477
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments (0)