Skip to content

Commit 0a9b48f

Browse files
committed
Use same constant in runtime for unset debug handle
In the runtime we have a contant number for unset debug handle. This diff bring that to python env for usage. Differential Revision: [D78132322](https://our.internmc.facebook.com/intern/diff/D78132322/) ghstack-source-id: 295530857 Pull Request resolved: #12385
1 parent 31ba959 commit 0a9b48f

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
from executorch.exir.debug_handle_utils import (
3939
DEBUG_HANDLE_KEY,
40+
UNSET_DEBUG_HANDLE,
4041
get_greatest_ancestor_node_identifier,
4142
)
4243

@@ -914,7 +915,7 @@ def propagate_back_debug_handle(
914915
where op1_0 is from op1, op3_0 and op3_1 are from op3, op2 is removed by to_edge pipeline (e.g. RemoveNoopPass).
915916
916917
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
917-
The debug handle of op2 will be a non-existing debug handle in edge dialect program for further skipping.
918+
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
918919
919920
Return: True if:
920921
a. every debug handle in the edge dialect program has a corresponding node in the exported program
@@ -935,11 +936,6 @@ def propagate_back_debug_handle(
935936
# number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
936937
n_matched_node = 0
937938

938-
# debug handle for the node in the exported program but not in the edge dialect program
939-
debug_handle_for_removed_node = (
940-
max(export_graph_node_id_to_debug_handle.values()) + 1
941-
)
942-
943939
def _find_n_match_node(node: torch.fx.Node) -> None:
944940
nonlocal n_matched_node
945941
if node.name in ("output", "placeholder"):
@@ -955,7 +951,7 @@ def _equip_debug_handle(node: torch.fx.Node) -> None:
955951
if node_id in export_graph_node_id_to_debug_handle:
956952
node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
957953
else:
958-
node.meta[DEBUG_HANDLE_KEY] = debug_handle_for_removed_node
954+
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
959955

960956
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
961957

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
)
4848
from executorch.devtools.inspector.numerical_comparator import L1Comparator
4949
from executorch.exir import to_edge
50-
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
50+
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY, UNSET_DEBUG_HANDLE
5151
from torch.export import export
5252

5353

@@ -682,19 +682,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
682682
)
683683
)
684684

685-
# only two add ops in the exported program will keep in edge dialect program, so the debug handles for removed op will be three
686-
debug_handle_for_removed_node = 3
685+
n_removed_nodes = 0
687686

688687
for node in exported_program.graph.nodes:
689688
if node.name == "add":
690689
self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 1)
691690
elif node.name == "add_1":
692691
self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 2)
693692
elif node.op not in ("placeholder", "output"):
693+
n_removed_nodes += 1
694694
self.assertEqual(
695-
node.meta[DEBUG_HANDLE_KEY], debug_handle_for_removed_node
695+
node.meta[DEBUG_HANDLE_KEY], UNSET_DEBUG_HANDLE
696696
)
697697

698+
self.assertEqual(n_removed_nodes, 2)
698699

699700
def gen_mock_operator_graph_with_expected_map() -> (
700701
Tuple[OperatorGraph, Dict[int, OperatorNode]]

exir/debug_handle_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
FROM_NODE_KEY = "from_node"
1010
DEBUG_HANDLE_KEY = "debug_handle"
1111

12+
UNSET_DEBUG_HANDLE = 0
1213

1314
def get_greatest_ancestor_node_identifier(node: Node) -> str:
1415
"""Get the identifier of the greatest ancestor node of the given node.

0 commit comments

Comments
 (0)