diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 24b9e030bf6..de6056780de 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -38,6 +38,7 @@ from executorch.exir.debug_handle_utils import ( DEBUG_HANDLE_KEY, get_greatest_ancestor_node_identifier, + UNSET_DEBUG_HANDLE, ) from executorch.exir.graph_module import bfs_trace_with_node_process @@ -917,7 +918,7 @@ def propagate_back_debug_handle( where op1_0 is from op1, op3_0 and op3_1 are from op3, op2 is removed by to_edge pipeline (e.g. RemoveNoopPass). Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1. - The debug handle of op2 will be a non-existing debug handle in edge dialect program for further skipping. + The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping. Return: True if: a. every debug handle in the edge dialect program has a corresponding node in the exported program @@ -938,11 +939,6 @@ def propagate_back_debug_handle( # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle n_matched_node = 0 - # debug handle for the node in the exported program but not in the edge dialect program - debug_handle_for_removed_node = ( - max(export_graph_node_id_to_debug_handle.values()) + 1 - ) - def _find_n_match_node(node: torch.fx.Node) -> None: nonlocal n_matched_node if node.name in ("output", "placeholder"): @@ -958,7 +954,7 @@ def _equip_debug_handle(node: torch.fx.Node) -> None: if node_id in export_graph_node_id_to_debug_handle: node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id] else: - node.meta[DEBUG_HANDLE_KEY] = debug_handle_for_removed_node + node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 2416fd89838..987e13e986d 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -47,7 +47,7 @@ ) from executorch.devtools.inspector.numerical_comparator import L1Comparator from executorch.exir import to_edge -from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY +from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY, UNSET_DEBUG_HANDLE from torch.export import export @@ -710,8 +710,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) ) - # only two add ops in the exported program will keep in edge dialect program, so the debug handles for removed op will be three - debug_handle_for_removed_node = 3 + n_removed_nodes = 0 for node in exported_program.graph.nodes: if node.name == "add": @@ -719,9 +718,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: elif node.name == "add_1": self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 2) elif node.op not in ("placeholder", "output"): - self.assertEqual( - node.meta[DEBUG_HANDLE_KEY], debug_handle_for_removed_node - ) + n_removed_nodes += 1 + self.assertEqual(node.meta[DEBUG_HANDLE_KEY], UNSET_DEBUG_HANDLE) + + self.assertEqual(n_removed_nodes, 2) def gen_mock_operator_graph_with_expected_map() -> ( diff --git a/exir/debug_handle_utils.py b/exir/debug_handle_utils.py index d1a70fcd213..771e47c79db 100644 --- a/exir/debug_handle_utils.py +++ b/exir/debug_handle_utils.py @@ -9,6 +9,8 @@ FROM_NODE_KEY = "from_node" DEBUG_HANDLE_KEY = "debug_handle" +UNSET_DEBUG_HANDLE = 0 + def get_greatest_ancestor_node_identifier(node: Node) -> str: """Get the identifier of the greatest ancestor node of the given node.