Skip to content

Use same constant in runtime for unset debug handle #12385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: gh/gasoonjia/21/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from executorch.exir.debug_handle_utils import (
DEBUG_HANDLE_KEY,
UNSET_DEBUG_HANDLE,
get_greatest_ancestor_node_identifier,
)

Expand Down Expand Up @@ -914,7 +915,7 @@ def propagate_back_debug_handle(
where op1_0 is from op1, op3_0 and op3_1 are from op3, op2 is removed by to_edge pipeline (e.g. RemoveNoopPass).

Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
The debug handle of op2 will be a non-existing debug handle in edge dialect program for further skipping.
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.

Return: True if:
a. every debug handle in the edge dialect program has a corresponding node in the exported program
Expand All @@ -935,11 +936,6 @@ def propagate_back_debug_handle(
# number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
n_matched_node = 0

# debug handle for the node in the exported program but not in the edge dialect program
debug_handle_for_removed_node = (
max(export_graph_node_id_to_debug_handle.values()) + 1
)

def _find_n_match_node(node: torch.fx.Node) -> None:
nonlocal n_matched_node
if node.name in ("output", "placeholder"):
Expand All @@ -955,7 +951,7 @@ def _equip_debug_handle(node: torch.fx.Node) -> None:
if node_id in export_graph_node_id_to_debug_handle:
node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
else:
node.meta[DEBUG_HANDLE_KEY] = debug_handle_for_removed_node
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE

bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)

Expand Down
9 changes: 5 additions & 4 deletions devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)
from executorch.devtools.inspector.numerical_comparator import L1Comparator
from executorch.exir import to_edge
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY, UNSET_DEBUG_HANDLE
from torch.export import export


Expand Down Expand Up @@ -682,19 +682,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
)

# only two add ops in the exported program will keep in edge dialect program, so the debug handles for removed op will be three
debug_handle_for_removed_node = 3
n_removed_nodes = 0

for node in exported_program.graph.nodes:
if node.name == "add":
self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 1)
elif node.name == "add_1":
self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 2)
elif node.op not in ("placeholder", "output"):
n_removed_nodes += 1
self.assertEqual(
node.meta[DEBUG_HANDLE_KEY], debug_handle_for_removed_node
node.meta[DEBUG_HANDLE_KEY], UNSET_DEBUG_HANDLE
)

self.assertEqual(n_removed_nodes, 2)

def gen_mock_operator_graph_with_expected_map() -> (
Tuple[OperatorGraph, Dict[int, OperatorNode]]
Expand Down
1 change: 1 addition & 0 deletions exir/debug_handle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
FROM_NODE_KEY = "from_node"
DEBUG_HANDLE_KEY = "debug_handle"

UNSET_DEBUG_HANDLE = 0

def get_greatest_ancestor_node_identifier(node: Node) -> str:
"""Get the identifier of the greatest ancestor node of the given node.
Expand Down
Loading