
Commit 3419b46

propagate debug handle from edge dialect graph back to exported graph (#12337)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #12330 by @Gasoonjia
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/19/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/19/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/18/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/19/orig
@diff-train-skip-merge

---------

Co-authored-by: gasoonjia <[email protected]>
Parent commit: e3cf5be
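For orientation, a minimal usage sketch of the flow this commit enables, based on the tests added below. The toy model, the tensor shapes, and the import path of propagate_back_debug_handle are assumptions for illustration, not part of the commit:

import torch
from executorch.devtools.inspector._inspector_utils import propagate_back_debug_handle
from executorch.exir import to_edge
from torch.export import export


class TwoAdds(torch.nn.Module):  # hypothetical toy model
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1 + 1


exported_program = export(TwoAdds(), (torch.rand(5, 5),))
export_graph_id = id(exported_program.graph)  # graph id used to match node identifiers
edge_dialect_program = to_edge(exported_program).exported_program()

# Copy each edge dialect node's debug handle back onto its ancestor node in the
# exported program; returns False if the two programs don't line up.
if propagate_back_debug_handle(exported_program, export_graph_id, edge_dialect_program):
    for node in exported_program.graph.nodes:
        if node.op not in ("placeholder", "output"):
            print(node.name, node.meta["debug_handle"])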

6 files changed: +232 -21 lines changed


devtools/inspector/_inspector_utils.py

Lines changed: 77 additions & 0 deletions
@@ -35,8 +35,17 @@
 from executorch.devtools.etdump.serialize import deserialize_from_etdump_flatcc
 from executorch.devtools.etrecord import ETRecord
 
+from executorch.exir.debug_handle_utils import (
+    DEBUG_HANDLE_KEY,
+    get_greatest_ancestor_node_identifier,
+)
+
+from executorch.exir.graph_module import bfs_trace_with_node_process
+
 from tabulate import tabulate
 
+from torch.export import ExportedProgram
+
 FORWARD = "forward"
 EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
 
@@ -888,3 +897,71 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
     else:
         # Raise an error if one is a sequence and the other is not
         raise ValueError("Both inputs must be sequences or both must be non-sequences.")
+
+
+def propagate_back_debug_handle(
+    exported_program: ExportedProgram,
+    exported_program_graph_id: int,
+    edge_dialect_program: ExportedProgram,
+) -> bool:
+    """
+    Propagate debug handles from the edge dialect program back to the exported program
+    while maintaining the correctness of operator tracing.
+
+    e.g.
+    export program: op1 -> op2 -> op3
+    edge dialect program: op1_0 -> op3_0 -> op3_1
+    where op1_0 is from op1, op3_0 and op3_1 are from op3, and op2 was removed by the
+    to_edge pipeline (e.g. RemoveNoopPass).
+
+    Then the debug handle of op1 should be the same as op1_0, and the debug handle of op3
+    should be the same as op3_0 and op3_1. op2 gets a debug handle that does not exist in
+    the edge dialect program, so it can be skipped later.
+
+    Return True if:
+    a. every debug handle in the edge dialect program has a corresponding node in the exported program
+    b. the exported program is the greatest ancestor of the edge dialect program
+
+    Otherwise, return False.
+    """
+
+    # 1. set up a mapping from the identifier of an exported program node to its debug handle,
+    #    using the edge dialect program nodes' debug handles and from_node info
+    export_graph_node_id_to_debug_handle = {
+        get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY]
+        for node in edge_dialect_program.graph.nodes
+        if node.op not in ("placeholder", "output")
+    }
+
+    # 2. equip the exported program's nodes with debug handles using the mapping
+    # number of nodes in the exported program with a matching entry in export_graph_node_id_to_debug_handle
+    n_matched_node = 0
+
+    # debug handle for nodes that are in the exported program but not in the edge dialect program
+    debug_handle_for_removed_node = (
+        max(export_graph_node_id_to_debug_handle.values()) + 1
+    )
+
+    def _find_n_match_node(node: torch.fx.Node) -> None:
+        nonlocal n_matched_node
+        if node.name in ("output", "placeholder"):
+            return
+        node_id = f"{node.name}.{exported_program_graph_id}"
+        if node_id in export_graph_node_id_to_debug_handle:
+            n_matched_node += 1
+
+    def _equip_debug_handle(node: torch.fx.Node) -> None:
+        if node.name in ("output", "placeholder"):
+            return
+        node_id = f"{node.name}.{exported_program_graph_id}"
+        if node_id in export_graph_node_id_to_debug_handle:
+            node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
+        else:
+            node.meta[DEBUG_HANDLE_KEY] = debug_handle_for_removed_node
+
+    bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
+
+    # if any node in the edge dialect program has no corresponding node in the exported program, the match failed
+    if n_matched_node != len(export_graph_node_id_to_debug_handle):
+        return False
+
+    bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
+    return True
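To make the docstring's op1/op2/op3 example concrete, here is a hypothetical view of the mapping the function builds; the node names and the graph id 1234 are invented for illustration:

# Edge dialect nodes op1_0, op3_0 and op3_1 point back to ancestors in the exported
# program, keyed as "<ancestor name>.<graph id>" by get_greatest_ancestor_node_identifier.
export_graph_node_id_to_debug_handle = {
    "op1.1234": 1,  # from op1_0
    "op3.1234": 2,  # from op3_0 and op3_1 (same ancestor, same handle)
}
# After propagate_back_debug_handle(exported_program, 1234, edge_dialect_program):
#   op1.meta["debug_handle"] == 1
#   op3.meta["debug_handle"] == 2
#   op2.meta["debug_handle"] == 3   # max existing handle + 1; absent from the edge graph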

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 113 additions & 1 deletion
@@ -10,8 +10,9 @@
 import unittest
 from typing import Dict, Tuple
 
-import torch
+import executorch.exir.tests.models as models
 
+import torch
 from executorch.devtools import generate_etrecord, parse_etrecord
 
 from executorch.devtools.debug_format.base_schema import (
@@ -41,9 +42,13 @@
     map_runtime_aot_intermediate_outputs,
     merge_runtime_overlapping_debug_handles,
     NodeFilter,
+    propagate_back_debug_handle,
     TimeScale,
 )
 from executorch.devtools.inspector.numerical_comparator import L1Comparator
+from executorch.exir import to_edge
+from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
+from torch.export import export
 
 
 class TestInspectorUtils(unittest.TestCase):
@@ -583,6 +588,113 @@ def test_compare_intermediate_outputs_sequence_and_non_sequence(self):
         with self.assertRaises(ValueError):
             compare_intermediate_outputs(a, b, L1Comparator())
 
+    def test_equip_debug_handle_to_export_program_success(self):
+        """Test that propagate_back_debug_handle returns True and properly equips debug handles."""
+        # Create a test model
+        model = models.FeedForwardBlock(5, 10)
+        inputs = (torch.rand(5, 5),)
+
+        # Export the model
+        exported_program = export(model, inputs)
+        export_graph_id = id(exported_program.graph)
+
+        # Convert to edge dialect
+        edge_dialect_program = to_edge(exported_program).exported_program()
+
+        # Call propagate_back_debug_handle
+        result = propagate_back_debug_handle(
+            exported_program, export_graph_id, edge_dialect_program
+        )
+
+        self.assertTrue(result)
+
+        # Check that debug handles are properly equipped in the exported program
+        exported_program_debug_handles = []
+        for node in exported_program.graph.nodes:
+            if node.op not in ("placeholder", "output"):
+                self.assertIn(DEBUG_HANDLE_KEY, node.meta)
+                self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
+                exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])
+
+        edge_dialect_program_debug_handles = []
+        for node in edge_dialect_program.graph.nodes:
+            if node.op not in ("placeholder", "output"):
+                self.assertIn(DEBUG_HANDLE_KEY, node.meta)
+                self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
+                edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])
+
+        # The 0th operator in the exported program (layer_norm) is decomposed into the 0th
+        # and 1st ops in the edge dialect graph (native_layer_norm and getitem), so they
+        # should share the same debug handle
+        self.assertEqual(
+            exported_program_debug_handles[0], edge_dialect_program_debug_handles[0]
+        )
+        self.assertEqual(
+            exported_program_debug_handles[0], edge_dialect_program_debug_handles[1]
+        )
+
+    def test_equip_debug_handle_to_export_program_failure(self):
+        """Test that propagate_back_debug_handle returns False when there's a mismatch."""
+        # Create a test model
+        model = models.FeedForwardBlock(5, 10)
+        inputs = (torch.rand(5, 5),)
+
+        exported_program = export(model, inputs)
+        edge_dialect_program = to_edge(exported_program).exported_program()
+
+        # Create a different exported program (re-export) to cause a mismatch
+        reexported_program = export(model, inputs)
+        reexport_graph_id = id(reexported_program.graph)
+
+        # Call propagate_back_debug_handle with mismatched programs
+        # This should return False because the re-exported program has different node identifiers
+        result = propagate_back_debug_handle(
+            reexported_program, reexport_graph_id, edge_dialect_program
+        )
+
+        # Check that it returns False due to the mismatch
+        self.assertFalse(result)
+
+    def test_equip_debug_handle_to_export_program_op_to_be_removed_in_to_edge(self):
+        """Test that propagate_back_debug_handle returns True and properly equips debug handles when an op is removed in to_edge."""
+
+        class M(torch.nn.Module):
+            """
+            Simple model with ops that will be removed in to_edge
+            """
+
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                x = x + 1
+                x = x.to(x.dtype)
+                x = x + 1
+                return x
+
+        inputs = (torch.rand(5, 5),)
+        exported_program = torch.export.export(M(), inputs)
+        export_graph_id = id(exported_program.graph)
+        edge_dialect_program = to_edge(exported_program).exported_program()
+
+        self.assertTrue(
+            propagate_back_debug_handle(
+                exported_program, export_graph_id, edge_dialect_program
+            )
+        )
+
+        # Only the two add ops in the exported program are kept in the edge dialect program,
+        # so the debug handle for the removed op will be three
+        debug_handle_for_removed_node = 3
+
+        for node in exported_program.graph.nodes:
+            if node.name == "add":
+                self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 1)
+            elif node.name == "add_1":
+                self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 2)
+            elif node.op not in ("placeholder", "output"):
+                self.assertEqual(
+                    node.meta[DEBUG_HANDLE_KEY], debug_handle_for_removed_node
+                )
+
 
 def gen_mock_operator_graph_with_expected_map() -> (
     Tuple[OperatorGraph, Dict[int, OperatorNode]]

exir/TARGETS

Lines changed: 8 additions & 0 deletions
@@ -277,3 +277,11 @@ python_library(
         "fbsource//third-party/pypi/typing-extensions:typing-extensions",
     ],
 )
+
+python_library(
+    name = "debug_handle_utils",
+    srcs = ["debug_handle_utils.py"],
+    deps = [
+        "//caffe2:torch",
+    ],
+)

exir/debug_handle_utils.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch.fx import Node
+
+FROM_NODE_KEY = "from_node"
+DEBUG_HANDLE_KEY = "debug_handle"
+
+
+def get_greatest_ancestor_node_identifier(node: Node) -> str:
+    """Get the identifier of the greatest ancestor node of the given node.
+
+    The identifier is the concatenation of the node name and graph id of the
+    greatest ancestor node, where the graph id is the unique id for every graph
+    module in the export flow and node name is unique within the same graph module.
+    """
+
+    node_source = node.meta[FROM_NODE_KEY]
+    node_source = node_source[-1]
+
+    while len(node_source.from_node) > 0:
+        node_source = node_source.from_node[-1]
+
+    return f"{node_source.name}.{str(node_source.graph_id)}"
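As a rough illustration of the provenance walk this helper performs, the snippet below uses a hand-rolled stand-in for the records stored under node.meta["from_node"]; the real entries are torch.fx node-source objects, only the name, graph_id, and from_node fields used by the helper are modeled, and the names and graph ids are made up:

from dataclasses import dataclass, field
from typing import List


@dataclass
class NodeSourceStub:  # hypothetical stand-in for a from_node provenance record
    name: str
    graph_id: int
    from_node: List["NodeSourceStub"] = field(default_factory=list)


# add_1 in an intermediate graph traces back to add in the originally exported graph.
root = NodeSourceStub("add", graph_id=111)
intermediate = NodeSourceStub("add_1", graph_id=222, from_node=[root])

# Same loop as get_greatest_ancestor_node_identifier: follow the newest provenance
# entry until a record with no further ancestors is reached.
node_source = intermediate
while len(node_source.from_node) > 0:
    node_source = node_source.from_node[-1]

print(f"{node_source.name}.{node_source.graph_id}")  # -> "add.111"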

exir/passes/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -342,6 +342,7 @@ python_library(
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/exir:debug_handle_utils",
         "//executorch/exir:graph_module",
         "//executorch/exir:pass_base",
     ],

exir/passes/debug_handle_generator_pass.py

Lines changed: 6 additions & 20 deletions
@@ -6,6 +6,11 @@
 
 from typing import Dict
 
+from executorch.exir.debug_handle_utils import (
+    DEBUG_HANDLE_KEY,
+    FROM_NODE_KEY,
+    get_greatest_ancestor_node_identifier,
+)
 from executorch.exir.graph_module import bfs_trace_with_node_process
 from executorch.exir.pass_base import ExportPass
 from torch.export import ExportedProgram
@@ -21,27 +26,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
         greatest ancestor node in the export flow.
         """
 
-        FROM_NODE_KEY = "from_node"
-        DEBUG_HANDLE_KEY = "debug_handle"
-
         source_node_id_to_debug_handle: Dict[str, int] = {}
 
-        def _get_greatest_ancestor_node_identifier(node: Node) -> str:
-            """Get the identifier of the greatest ancestor node of the given node.
-
-            The identifier is the concatenation of the node name and graph id of the
-            greatest ancestor node, where the graph id is the unique id for every graph
-            module in the export flow and node name is unique within the same graph module.
-            """
-
-            node_source = node.meta[FROM_NODE_KEY]
-            node_source = node_source[-1]
-
-            while len(node_source.from_node) > 0:
-                node_source = node_source.from_node[-1]
-
-            return node_source.name + str(node_source.graph_id)
-
         def _extract_debug_handles_from_node(node: Node) -> None:
             """
             Generate a debug handle based on node's oldest ancestor node's name
@@ -56,7 +42,7 @@ def _extract_debug_handles_from_node(node: Node) -> None:
                 FROM_NODE_KEY in node.meta
             ), f"Node {node} does not have meta key {FROM_NODE_KEY}"
 
-            greatest_ancestor_node_id = _get_greatest_ancestor_node_identifier(node)
+            greatest_ancestor_node_id = get_greatest_ancestor_node_identifier(node)
 
             debug_handle = (
                 len(source_node_id_to_debug_handle) + 1
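For context, a sketch of the handle-assignment pattern the pass's _extract_debug_handles_from_node builds on; the hunk above is truncated, so this is a paraphrase of the behavior described in the docstrings rather than the file's exact code: nodes that trace back to the same greatest-ancestor identifier receive the same 1-based debug handle.

from typing import Dict

source_node_id_to_debug_handle: Dict[str, int] = {}


def debug_handle_for(ancestor_id: str) -> int:
    # Reuse the handle if this ancestor was already seen; otherwise hand out the
    # next 1-based handle, mirroring len(source_node_id_to_debug_handle) + 1.
    if ancestor_id not in source_node_id_to_debug_handle:
        source_node_id_to_debug_handle[ancestor_id] = (
            len(source_node_id_to_debug_handle) + 1
        )
    return source_node_id_to_debug_handle[ancestor_id]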
