Skip to content

make etrecord support export program #12336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 59 additions & 10 deletions devtools/etrecord/_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class StrEnum(str, Enum):

class ETRecordReservedFileNames(StrEnum):
ETRECORD_IDENTIFIER = "ETRECORD_V0"
EXPORTED_PROGRAM = "exported_program"
EXPORT_GRAPH_ID = "export_graph_id"
EDGE_DIALECT_EXPORTED_PROGRAM = "edge_dialect_exported_program"
ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module"
DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
Expand All @@ -55,6 +57,8 @@ class ETRecordReservedFileNames(StrEnum):

@dataclass
class ETRecord:
exported_program: Optional[ExportedProgram] = None
export_graph_id: Optional[int] = None
edge_dialect_program: Optional[ExportedProgram] = None
graph_map: Optional[Dict[str, ExportedProgram]] = None
_debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None
Expand All @@ -71,17 +75,20 @@ def _handle_exported_program(
assert isinstance(ep, ExportedProgram)
serialized_artifact = serialize(ep)
assert isinstance(serialized_artifact.exported_program, bytes)

method_name = f"/{method_name}" if method_name != "" else ""

etrecord_zip.writestr(
f"{module_name}/{method_name}", serialized_artifact.exported_program
f"{module_name}{method_name}", serialized_artifact.exported_program
)
etrecord_zip.writestr(
f"{module_name}/{method_name}_state_dict", serialized_artifact.state_dict
f"{module_name}{method_name}_state_dict", serialized_artifact.state_dict
)
etrecord_zip.writestr(
f"{module_name}/{method_name}_constants", serialized_artifact.constants
f"{module_name}{method_name}_constants", serialized_artifact.constants
)
etrecord_zip.writestr(
f"{module_name}/{method_name}_example_inputs",
f"{module_name}{method_name}_example_inputs",
serialized_artifact.example_inputs,
)

Expand Down Expand Up @@ -188,7 +195,10 @@ def generate_etrecord(
ExecutorchProgramManager,
BundledProgram,
],
export_modules: Optional[
exported_program: Optional[
Union[ExportedProgram, Dict[str, ExportedProgram]]
] = None,
extra_recorded_export_modules: Optional[
Dict[
str,
Union[
Expand All @@ -202,7 +212,7 @@ def generate_etrecord(
"""
Generates an `ETRecord` from the given objects, serializes it and saves it to the given path.
The objects that will be serialized to an `ETRecord` are all the graph modules present
in the `export_modules` dict, the graph module present in the edge dialect program object,
in the `extra_recorded_export_modules` dict, the graph module present in the edge dialect program object,
and also the graph module present in the ExecuTorch program object, which
is the closest graph module representation of what is eventually run on the device.
In addition to all the graph modules, we also serialize the program buffer, which the users
Expand All @@ -213,7 +223,8 @@ def generate_etrecord(
et_record: Path to where the `ETRecord` file will be saved to.
edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
executorch_program: The ExecuTorch program for this model returned by the call to `to_executorch()` or the `BundledProgram` of this model
export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
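exported_program [Optional]: The `ExportedProgram` for this model returned by the call to `torch.export()`. A dictionary mapping method names to `ExportedProgram`s may also be passed, in which case only the entry under the `forward` key is recorded.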
extra_recorded_export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
value being the corresponding exported module. The exported graph modules can be either the
output of `torch.export()` or `exir.to_edge()`.

Expand All @@ -229,15 +240,32 @@ def generate_etrecord(
# is an etrecord when it's used later in the Developer Tools.
etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")

if export_modules is not None:
for module_name, export_module in export_modules.items():
# Calculate export_graph_id before modifying exported_program
export_graph_id = 0

if exported_program is not None:
# If multiple exported programs are provided, only save forward method
if isinstance(exported_program, dict) and "forward" in exported_program:
exported_program = exported_program["forward"]

if isinstance(exported_program, ExportedProgram):
export_graph_id = id(exported_program.graph)
_handle_exported_program(
etrecord_zip,
ETRecordReservedFileNames.EXPORTED_PROGRAM,
"",
exported_program,
)

if extra_recorded_export_modules is not None:
for module_name, export_module in extra_recorded_export_modules.items():
contains_reserved_name = any(
reserved_name in module_name
for reserved_name in ETRecordReservedFileNames
)
if contains_reserved_name:
raise RuntimeError(
f"The name {module_name} provided in the export_modules dict is a reserved name in the ETRecord namespace."
f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace."
)
_handle_export_module(etrecord_zip, export_module, module_name)

Expand Down Expand Up @@ -286,6 +314,11 @@ def generate_etrecord(
json.dumps(executorch_program.delegate_map),
)

etrecord_zip.writestr(
ETRecordReservedFileNames.EXPORT_GRAPH_ID,
json.dumps(export_graph_id),
)


def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
"""
Expand Down Expand Up @@ -318,9 +351,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
graph_map: Dict[str, ExportedProgram] = {}
debug_handle_map = None
delegate_map = None
exported_program = None
edge_dialect_program = None
reference_outputs = None
representative_inputs = None
export_graph_id = 0

serialized_exported_program_files = set()
serialized_state_dict_files = set()
Expand All @@ -347,6 +382,14 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
etrecord_zip.read(f"{entry}_example_inputs"),
)
edge_dialect_program = deserialize(serialized_artifact)
elif entry == ETRecordReservedFileNames.EXPORTED_PROGRAM:
serialized_artifact = SerializedArtifact(
etrecord_zip.read(ETRecordReservedFileNames.EXPORTED_PROGRAM),
etrecord_zip.read(f"{entry}_state_dict"),
etrecord_zip.read(f"{entry}_constants"),
etrecord_zip.read(f"{entry}_example_inputs"),
)
exported_program = deserialize(serialized_artifact)
elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS:
# @lint-ignore PYTHONPICKLEISBAD
reference_outputs = pickle.loads(
Expand All @@ -357,6 +400,10 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
representative_inputs = pickle.loads(
etrecord_zip.read(ETRecordReservedFileNames.REPRESENTATIVE_INPUTS)
)
elif entry == ETRecordReservedFileNames.EXPORT_GRAPH_ID:
export_graph_id = json.loads(
etrecord_zip.read(ETRecordReservedFileNames.EXPORT_GRAPH_ID)
)
else:
if entry.endswith("state_dict"):
serialized_state_dict_files.add(entry)
Expand All @@ -383,10 +430,12 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
graph_map[serialized_file] = deserialize(serialized_artifact)

return ETRecord(
exported_program=exported_program,
edge_dialect_program=edge_dialect_program,
graph_map=graph_map,
_debug_handle_map=debug_handle_map,
_delegate_map=delegate_map,
_reference_outputs=reference_outputs,
_representative_inputs=representative_inputs,
export_graph_id=export_graph_id,
)
86 changes: 83 additions & 3 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ def test_etrecord_generation(self):
tmpdirname + "/etrecord.bin",
edge_output,
et_output,
{
extra_recorded_export_modules={
"aten_dialect_output": captured_output,
},
)

etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")

self.check_graph_closeness(
etrecord.graph_map["aten_dialect_output/forward"],
captured_output.exported_program.graph_module,
Expand Down Expand Up @@ -184,7 +185,7 @@ def test_etrecord_invalid_input(self):
tmpdirname + "/etrecord.bin",
edge_output,
et_output,
{"fail_test_case": et_output},
extra_recorded_export_modules={"fail_test_case": et_output},
)

def test_etrecord_reserved_name(self):
Expand All @@ -196,5 +197,84 @@ def test_etrecord_reserved_name(self):
tmpdirname + "/etrecord.bin",
edge_output,
et_output,
{reserved_name: captured_output.exported_program.graph_module},
extra_recorded_export_modules={
reserved_name: captured_output.exported_program.graph_module
},
)

def test_etrecord_generation_with_exported_program(self):
    """Round-trip a single exported program through generate/parse of an ETRecord.

    Passes the `ExportedProgram` from the test model directly as
    `exported_program`, then checks that the parsed record exposes it,
    still carries the edge dialect program and debug handle map, and
    reports the graph id taken from the original program's graph.
    """
    captured_output, edge_output, et_output = self.get_test_model()
    source_program = captured_output.exported_program
    # Captured up front so it reflects the graph object handed to generate_etrecord.
    graph_id = id(source_program.graph)

    with tempfile.TemporaryDirectory() as tmpdirname:
        etrecord_path = tmpdirname + "/etrecord.bin"

        # Write the record, including the exported program.
        generate_etrecord(
            etrecord_path,
            edge_output,
            et_output,
            exported_program=source_program,
        )

        # Read it back.
        parsed = parse_etrecord(etrecord_path)

        # The exported program must be present and close to the original graph.
        self.assertIsNotNone(parsed.exported_program)
        self.check_graph_closeness(
            parsed.exported_program, source_program.graph_module
        )

        # The pre-existing components must be unaffected by the new field.
        self.check_graph_closeness(
            parsed.edge_dialect_program,
            edge_output.exported_program.graph_module,
        )
        # JSON round-trip normalises the expected map the same way the record does.
        expected_handle_map = json.loads(json.dumps(et_output.debug_handle_map))
        self.assertEqual(parsed._debug_handle_map, expected_handle_map)

        # The stored graph id must equal the id computed before generation.
        self.assertEqual(parsed.export_graph_id, graph_id)

def test_etrecord_generation_with_exported_program_dict(self):
    """Round-trip a `{"forward": ExportedProgram}` dict through an ETRecord.

    Passes the dict as `exported_program`, then checks that the parsed
    record exposes the program stored under "forward", still carries the
    edge dialect program and debug handle map, and reports the graph id
    taken from the original program's graph.
    """
    captured_output, edge_output, et_output = self.get_test_model()
    source_program = captured_output.exported_program
    programs_by_method = {"forward": source_program}
    # Captured up front so it reflects the graph object handed to generate_etrecord.
    graph_id = id(source_program.graph)

    with tempfile.TemporaryDirectory() as tmpdirname:
        etrecord_path = tmpdirname + "/etrecord.bin"

        # Write the record, supplying the exported programs as a dict.
        generate_etrecord(
            etrecord_path,
            edge_output,
            et_output,
            exported_program=programs_by_method,
        )

        # Read it back.
        parsed = parse_etrecord(etrecord_path)

        # The exported program must be present and close to the original graph.
        self.assertIsNotNone(parsed.exported_program)
        self.check_graph_closeness(
            parsed.exported_program, source_program.graph_module
        )

        # The pre-existing components must be unaffected by the new field.
        self.check_graph_closeness(
            parsed.edge_dialect_program,
            edge_output.exported_program.graph_module,
        )
        # JSON round-trip normalises the expected map the same way the record does.
        expected_handle_map = json.loads(json.dumps(et_output.debug_handle_map))
        self.assertEqual(parsed._debug_handle_map, expected_handle_map)

        # The stored graph id must equal the id computed before generation.
        self.assertEqual(parsed.export_graph_id, graph_id)
2 changes: 1 addition & 1 deletion devtools/inspector/tests/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_inspector_get_exported_program(self):
tmpdirname + "/etrecord.bin",
edge_output,
et_output,
{
extra_recorded_export_modules={
"aten_dialect_output": captured_output,
},
)
Expand Down
2 changes: 1 addition & 1 deletion devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_gen_graphs_from_etrecord(self):
tmpdirname + "/etrecord.bin",
edge_output,
et_output,
{
extra_recorded_export_modules={
"aten_dialect_output": captured_output,
},
)
Expand Down
2 changes: 1 addition & 1 deletion examples/devtools/scripts/gen_sample_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def gen_etrecord(model: torch.nn.Module, inputs: Any, output_path=None):
(DEFAULT_OUTPUT_PATH if not output_path else output_path),
edge_dialect_program=edge_program,
executorch_program=et_program,
export_modules={
extra_recorded_export_modules={
"aten_dialect_output": aten_dialect,
},
)
Expand Down
Loading