Skip to content

propagate debug handle from edge dialect graph back to exported graph #12337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 59 additions & 10 deletions devtools/etrecord/_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class StrEnum(str, Enum):

class ETRecordReservedFileNames(StrEnum):
ETRECORD_IDENTIFIER = "ETRECORD_V0"
EXPORTED_PROGRAM = "exported_program"
EXPORT_GRAPH_ID = "export_graph_id"
EDGE_DIALECT_EXPORTED_PROGRAM = "edge_dialect_exported_program"
ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module"
DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
Expand All @@ -55,6 +57,8 @@ class ETRecordReservedFileNames(StrEnum):

@dataclass
class ETRecord:
exported_program: Optional[ExportedProgram] = None
export_graph_id: Optional[int] = None
edge_dialect_program: Optional[ExportedProgram] = None
graph_map: Optional[Dict[str, ExportedProgram]] = None
_debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None
Expand All @@ -71,17 +75,20 @@ def _handle_exported_program(
assert isinstance(ep, ExportedProgram)
serialized_artifact = serialize(ep)
assert isinstance(serialized_artifact.exported_program, bytes)

method_name = f"/{method_name}" if method_name != "" else ""

etrecord_zip.writestr(
f"{module_name}/{method_name}", serialized_artifact.exported_program
f"{module_name}{method_name}", serialized_artifact.exported_program
)
etrecord_zip.writestr(
f"{module_name}/{method_name}_state_dict", serialized_artifact.state_dict
f"{module_name}{method_name}_state_dict", serialized_artifact.state_dict
)
etrecord_zip.writestr(
f"{module_name}/{method_name}_constants", serialized_artifact.constants
f"{module_name}{method_name}_constants", serialized_artifact.constants
)
etrecord_zip.writestr(
f"{module_name}/{method_name}_example_inputs",
f"{module_name}{method_name}_example_inputs",
serialized_artifact.example_inputs,
)

Expand Down Expand Up @@ -188,7 +195,10 @@ def generate_etrecord(
ExecutorchProgramManager,
BundledProgram,
],
export_modules: Optional[
exported_program: Optional[
Union[ExportedProgram, Dict[str, ExportedProgram]]
] = None,
extra_recorded_export_modules: Optional[
Dict[
str,
Union[
Expand All @@ -202,7 +212,7 @@ def generate_etrecord(
"""
Generates an `ETRecord` from the given objects, serializes it and saves it to the given path.
The objects that will be serialized to an `ETRecord` are all the graph modules present
in the `export_modules` dict, the graph module present in the edge dialect program object,
in the `extra_recorded_export_modules` dict, the graph module present in the edge dialect program object,
and also the graph module present in the ExecuTorch program object, which
is the closest graph module representation of what is eventually run on the device.
In addition to all the graph modules, we also serialize the program buffer, which the users
Expand All @@ -213,7 +223,8 @@ def generate_etrecord(
et_record: Path to where the `ETRecord` file will be saved to.
edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
executorch_program: The ExecuTorch program for this model returned by the call to `to_executorch()` or the `BundledProgram` of this model
export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
exported_program: Optional graph module for this model returned by the call to `torch.export` from nn.Module.
extra_recorded_export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
value being the corresponding exported module. The exported graph modules can be either the
output of `torch.export()` or `exir.to_edge()`.

Expand All @@ -229,15 +240,32 @@ def generate_etrecord(
# is an etrecord when it's used later in the Developer Tools.
etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")

if export_modules is not None:
for module_name, export_module in export_modules.items():
# Calculate export_graph_id before modifying exported_program
export_graph_id = 0

if exported_program is not None:
# If multiple exported programs are provided, only save forward method
if isinstance(exported_program, dict) and "forward" in exported_program:
exported_program = exported_program["forward"]

if isinstance(exported_program, ExportedProgram):
export_graph_id = id(exported_program.graph)
_handle_exported_program(
etrecord_zip,
ETRecordReservedFileNames.EXPORTED_PROGRAM,
"",
exported_program,
)

if extra_recorded_export_modules is not None:
for module_name, export_module in extra_recorded_export_modules.items():
contains_reserved_name = any(
reserved_name in module_name
for reserved_name in ETRecordReservedFileNames
)
if contains_reserved_name:
raise RuntimeError(
f"The name {module_name} provided in the export_modules dict is a reserved name in the ETRecord namespace."
f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace."
)
_handle_export_module(etrecord_zip, export_module, module_name)

Expand Down Expand Up @@ -286,6 +314,11 @@ def generate_etrecord(
json.dumps(executorch_program.delegate_map),
)

etrecord_zip.writestr(
ETRecordReservedFileNames.EXPORT_GRAPH_ID,
json.dumps(export_graph_id),
)


def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
"""
Expand Down Expand Up @@ -318,9 +351,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
graph_map: Dict[str, ExportedProgram] = {}
debug_handle_map = None
delegate_map = None
exported_program = None
edge_dialect_program = None
reference_outputs = None
representative_inputs = None
export_graph_id = 0

serialized_exported_program_files = set()
serialized_state_dict_files = set()
Expand All @@ -347,6 +382,14 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
etrecord_zip.read(f"{entry}_example_inputs"),
)
edge_dialect_program = deserialize(serialized_artifact)
elif entry == ETRecordReservedFileNames.EXPORTED_PROGRAM:
serialized_artifact = SerializedArtifact(
etrecord_zip.read(ETRecordReservedFileNames.EXPORTED_PROGRAM),
etrecord_zip.read(f"{entry}_state_dict"),
etrecord_zip.read(f"{entry}_constants"),
etrecord_zip.read(f"{entry}_example_inputs"),
)
exported_program = deserialize(serialized_artifact)
elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS:
# @lint-ignore PYTHONPICKLEISBAD
reference_outputs = pickle.loads(
Expand All @@ -357,6 +400,10 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
representative_inputs = pickle.loads(
etrecord_zip.read(ETRecordReservedFileNames.REPRESENTATIVE_INPUTS)
)
elif entry == ETRecordReservedFileNames.EXPORT_GRAPH_ID:
export_graph_id = json.loads(
etrecord_zip.read(ETRecordReservedFileNames.EXPORT_GRAPH_ID)
)
else:
if entry.endswith("state_dict"):
serialized_state_dict_files.add(entry)
Expand All @@ -383,10 +430,12 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
graph_map[serialized_file] = deserialize(serialized_artifact)

return ETRecord(
exported_program=exported_program,
edge_dialect_program=edge_dialect_program,
graph_map=graph_map,
_debug_handle_map=debug_handle_map,
_delegate_map=delegate_map,
_reference_outputs=reference_outputs,
_representative_inputs=representative_inputs,
export_graph_id=export_graph_id,
)
86 changes: 83 additions & 3 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ def test_etrecord_generation(self):
tmpdirname + "/etrecord.bin",
edge_output,
et_output,
{
extra_recorded_export_modules={
"aten_dialect_output": captured_output,
},
)

etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")

self.check_graph_closeness(
etrecord.graph_map["aten_dialect_output/forward"],
captured_output.exported_program.graph_module,
Expand Down Expand Up @@ -184,7 +185,7 @@ def test_etrecord_invalid_input(self):
tmpdirname + "/etrecord.bin",
edge_output,
et_output,
{"fail_test_case": et_output},
extra_recorded_export_modules={"fail_test_case": et_output},
)

def test_etrecord_reserved_name(self):
Expand All @@ -196,5 +197,84 @@ def test_etrecord_reserved_name(self):
tmpdirname + "/etrecord.bin",
edge_output,
et_output,
{reserved_name: captured_output.exported_program.graph_module},
extra_recorded_export_modules={
reserved_name: captured_output.exported_program.graph_module
},
)

def test_etrecord_generation_with_exported_program(self):
    """Round-trip an ETRecord that carries a single exported program.

    Generates an ETRecord with the `exported_program` argument set to one
    `ExportedProgram`, parses it back, and checks that the exported program,
    the edge dialect program, the debug handle map and the recorded
    export graph id all survive the round trip.
    """
    captured_output, edge_output, et_output = self.get_test_model()
    source_program = captured_output.exported_program
    # Captured up front so it reflects the graph object handed to generate_etrecord.
    graph_id_before_generation = id(source_program.graph)

    with tempfile.TemporaryDirectory() as tmpdirname:
        record_path = tmpdirname + "/etrecord.bin"

        # Write the record, then read it straight back.
        generate_etrecord(
            record_path,
            edge_output,
            et_output,
            exported_program=source_program,
        )
        parsed = parse_etrecord(record_path)

        # The exported program must be present and structurally close to the original.
        self.assertIsNotNone(parsed.exported_program)
        self.check_graph_closeness(
            parsed.exported_program,
            source_program.graph_module,
        )

        # Pre-existing components must be unaffected by the new field.
        self.check_graph_closeness(
            parsed.edge_dialect_program,
            edge_output.exported_program.graph_module,
        )
        # JSON round trip normalises the expected map the same way the record does.
        expected_handle_map = json.loads(json.dumps(et_output.debug_handle_map))
        self.assertEqual(parsed._debug_handle_map, expected_handle_map)

        # The stored graph id must equal the id of the original graph object.
        self.assertEqual(parsed.export_graph_id, graph_id_before_generation)

def test_etrecord_generation_with_exported_program_dict(self):
    """Round-trip an ETRecord built from a dict of exported programs.

    Passes `{"forward": ExportedProgram}` as `exported_program`, parses the
    generated record back, and checks the exported program, edge dialect
    program, debug handle map and export graph id.
    """
    captured_output, edge_output, et_output = self.get_test_model()
    source_program = captured_output.exported_program
    programs_by_method = {"forward": source_program}
    # Captured up front so it reflects the graph object inside the dict.
    graph_id_before_generation = id(source_program.graph)

    with tempfile.TemporaryDirectory() as tmpdirname:
        record_path = tmpdirname + "/etrecord.bin"

        # Write the record from the dict form, then read it straight back.
        generate_etrecord(
            record_path,
            edge_output,
            et_output,
            exported_program=programs_by_method,
        )
        parsed = parse_etrecord(record_path)

        # The exported program must be present and structurally close to the original.
        self.assertIsNotNone(parsed.exported_program)
        self.check_graph_closeness(
            parsed.exported_program,
            source_program.graph_module,
        )

        # Pre-existing components must be unaffected by the new field.
        self.check_graph_closeness(
            parsed.edge_dialect_program,
            edge_output.exported_program.graph_module,
        )
        # JSON round trip normalises the expected map the same way the record does.
        expected_handle_map = json.loads(json.dumps(et_output.debug_handle_map))
        self.assertEqual(parsed._debug_handle_map, expected_handle_map)

        # The stored graph id must equal the id of the original graph object.
        self.assertEqual(parsed.export_graph_id, graph_id_before_generation)
77 changes: 77 additions & 0 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,17 @@
from executorch.devtools.etdump.serialize import deserialize_from_etdump_flatcc
from executorch.devtools.etrecord import ETRecord

from executorch.exir.debug_handle_utils import (
DEBUG_HANDLE_KEY,
get_greatest_ancestor_node_identifier,
)

from executorch.exir.graph_module import bfs_trace_with_node_process

from tabulate import tabulate

from torch.export import ExportedProgram

FORWARD = "forward"
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"

Expand Down Expand Up @@ -888,3 +897,71 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
else:
# Raise an error if one is a sequence and the other is not
raise ValueError("Both inputs must be sequences or both must be non-sequences.")


def propagate_back_debug_handle(
    exported_program: ExportedProgram,
    exported_program_graph_id: int,
    edge_dialect_program: ExportedProgram,
) -> bool:
    """
    Propagate debug handles from the edge dialect program back to the exported
    program while maintaining the correctness of operator tracing.

    e.g.
    export program: op1 -> op2 -> op3
    edge dialect program: op1_0 -> op3_0 -> op3_1
    where op1_0 is from op1, op3_0 and op3_1 are from op3, op2 is removed by to_edge pipeline (e.g. RemoveNoopPass).

    Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
    The debug handle of op2 will be a debug handle that does not exist in the edge dialect program,
    so that it can be skipped later.

    Args:
        exported_program: program whose graph nodes receive the debug handles
            (their `meta[DEBUG_HANDLE_KEY]` is written in place on success).
        exported_program_graph_id: id used, together with each node's name, to build
            the `"<node.name>.<graph_id>"` identifier looked up in the mapping.
        edge_dialect_program: program whose nodes supply the debug handles and the
            ancestor identifiers.

    Returns:
        True if:
            a. every debug handle in the edge dialect program has a corresponding node in the exported program
            b. the exported program is the greatest ancestor of the edge dialect program
        Otherwise False, in which case the exported program is left untouched.
    """

    # 1. set up a mapping from identifier of export program's node to debug handle
    # using edge dialect program nodes' debug handles and from_node info
    export_graph_node_id_to_debug_handle = {
        get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY]
        for node in edge_dialect_program.graph.nodes
        if node.op not in ("placeholder", "output")
    }

    # 2. equip debug handle to the exported program's nodes using the mapping
    # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
    n_matched_node = 0

    # debug handle for the node in the exported program but not in the edge dialect program.
    # `default=0` keeps this well-defined when the edge dialect graph contributes no
    # entries (e.g. it only has placeholder/output nodes); a bare max() would raise
    # ValueError on the empty mapping.
    debug_handle_for_removed_node = (
        max(export_graph_node_id_to_debug_handle.values(), default=0) + 1
    )

    # NOTE(review): the skip checks below test `node.name`, whereas the mapping above
    # filters on `node.op`. This looks like it was meant to be `node.op` — confirm
    # before changing, since it affects which nodes receive a debug handle.

    def _find_n_match_node(node: torch.fx.Node) -> None:
        nonlocal n_matched_node
        if node.name in ("output", "placeholder"):
            return
        node_id = f"{node.name}.{exported_program_graph_id}"
        if node_id in export_graph_node_id_to_debug_handle:
            n_matched_node += 1

    def _equip_debug_handle(node: torch.fx.Node) -> None:
        if node.name in ("output", "placeholder"):
            return
        node_id = f"{node.name}.{exported_program_graph_id}"
        # Unmatched nodes get the sentinel handle reserved for removed nodes.
        node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle.get(
            node_id, debug_handle_for_removed_node
        )

    # First pass only counts matches so that nothing is mutated if matching fails.
    bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)

    # if any node in the edge dialect program has no corresponding node in the exported program, match failed
    if n_matched_node != len(export_graph_node_id_to_debug_handle):
        return False

    # Second pass writes the debug handles onto the exported program's nodes.
    bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
    return True
2 changes: 1 addition & 1 deletion devtools/inspector/tests/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_inspector_get_exported_program(self):
tmpdirname + "/etrecord.bin",
edge_output,
et_output,
{
extra_recorded_export_modules={
"aten_dialect_output": captured_output,
},
)
Expand Down
Loading
Loading