Skip to content

Commit e3cf5be

Browse files
make etrecord support export program (#12336)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #12288 by @Gasoonjia
^ Please use this as the source of truth for the PR details, comments, and reviews.
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/18/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/18/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/18/orig
@diff-train-skip-merge
Co-authored-by: gasoonjia <[email protected]>
1 parent 986b447 commit e3cf5be

File tree

5 files changed

+145
-16
lines changed

5 files changed

+145
-16
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class StrEnum(str, Enum):
4545

4646
class ETRecordReservedFileNames(StrEnum):
4747
ETRECORD_IDENTIFIER = "ETRECORD_V0"
48+
EXPORTED_PROGRAM = "exported_program"
49+
EXPORT_GRAPH_ID = "export_graph_id"
4850
EDGE_DIALECT_EXPORTED_PROGRAM = "edge_dialect_exported_program"
4951
ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module"
5052
DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
@@ -55,6 +57,8 @@ class ETRecordReservedFileNames(StrEnum):
5557

5658
@dataclass
5759
class ETRecord:
60+
exported_program: Optional[ExportedProgram] = None
61+
export_graph_id: Optional[int] = None
5862
edge_dialect_program: Optional[ExportedProgram] = None
5963
graph_map: Optional[Dict[str, ExportedProgram]] = None
6064
_debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None
@@ -71,17 +75,20 @@ def _handle_exported_program(
7175
assert isinstance(ep, ExportedProgram)
7276
serialized_artifact = serialize(ep)
7377
assert isinstance(serialized_artifact.exported_program, bytes)
78+
79+
method_name = f"/{method_name}" if method_name != "" else ""
80+
7481
etrecord_zip.writestr(
75-
f"{module_name}/{method_name}", serialized_artifact.exported_program
82+
f"{module_name}{method_name}", serialized_artifact.exported_program
7683
)
7784
etrecord_zip.writestr(
78-
f"{module_name}/{method_name}_state_dict", serialized_artifact.state_dict
85+
f"{module_name}{method_name}_state_dict", serialized_artifact.state_dict
7986
)
8087
etrecord_zip.writestr(
81-
f"{module_name}/{method_name}_constants", serialized_artifact.constants
88+
f"{module_name}{method_name}_constants", serialized_artifact.constants
8289
)
8390
etrecord_zip.writestr(
84-
f"{module_name}/{method_name}_example_inputs",
91+
f"{module_name}{method_name}_example_inputs",
8592
serialized_artifact.example_inputs,
8693
)
8794

@@ -188,7 +195,10 @@ def generate_etrecord(
188195
ExecutorchProgramManager,
189196
BundledProgram,
190197
],
191-
export_modules: Optional[
198+
exported_program: Optional[
199+
Union[ExportedProgram, Dict[str, ExportedProgram]]
200+
] = None,
201+
extra_recorded_export_modules: Optional[
192202
Dict[
193203
str,
194204
Union[
@@ -202,7 +212,7 @@ def generate_etrecord(
202212
"""
203213
Generates an `ETRecord` from the given objects, serializes it and saves it to the given path.
204214
The objects that will be serialized to an `ETRecord` are all the graph modules present
205-
in the `export_modules` dict, the graph module present in the edge dialect program object,
215+
in the `extra_recorded_export_modules` dict, the graph module present in the edge dialect program object,
206216
and also the graph module present in the ExecuTorch program object, which
207217
is the closest graph module representation of what is eventually run on the device.
208218
In addition to all the graph modules, we also serialize the program buffer, which the users
@@ -213,7 +223,8 @@ def generate_etrecord(
213223
et_record: Path to where the `ETRecord` file will be saved to.
214224
edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
215225
executorch_program: The ExecuTorch program for this model returned by the call to `to_executorch()` or the `BundledProgram` of this model
216-
export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
226+
exported_program: Optional graph module for this model returned by the call to `torch.export` from nn.Module.
227+
extra_recorded_export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
217228
value being the corresponding exported module. The exported graph modules can be either the
218229
output of `torch.export()` or `exir.to_edge()`.
219230
@@ -229,15 +240,32 @@ def generate_etrecord(
229240
# is an etrecord when it's used later in the Developer Tools.
230241
etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")
231242

232-
if export_modules is not None:
233-
for module_name, export_module in export_modules.items():
243+
# Calculate export_graph_id before modifying exported_program
244+
export_graph_id = 0
245+
246+
if exported_program is not None:
247+
# If multiple exported programs are provided, only save forward method
248+
if isinstance(exported_program, dict) and "forward" in exported_program:
249+
exported_program = exported_program["forward"]
250+
251+
if isinstance(exported_program, ExportedProgram):
252+
export_graph_id = id(exported_program.graph)
253+
_handle_exported_program(
254+
etrecord_zip,
255+
ETRecordReservedFileNames.EXPORTED_PROGRAM,
256+
"",
257+
exported_program,
258+
)
259+
260+
if extra_recorded_export_modules is not None:
261+
for module_name, export_module in extra_recorded_export_modules.items():
234262
contains_reserved_name = any(
235263
reserved_name in module_name
236264
for reserved_name in ETRecordReservedFileNames
237265
)
238266
if contains_reserved_name:
239267
raise RuntimeError(
240-
f"The name {module_name} provided in the export_modules dict is a reserved name in the ETRecord namespace."
268+
f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace."
241269
)
242270
_handle_export_module(etrecord_zip, export_module, module_name)
243271

@@ -286,6 +314,11 @@ def generate_etrecord(
286314
json.dumps(executorch_program.delegate_map),
287315
)
288316

317+
etrecord_zip.writestr(
318+
ETRecordReservedFileNames.EXPORT_GRAPH_ID,
319+
json.dumps(export_graph_id),
320+
)
321+
289322

290323
def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
291324
"""
@@ -318,9 +351,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
318351
graph_map: Dict[str, ExportedProgram] = {}
319352
debug_handle_map = None
320353
delegate_map = None
354+
exported_program = None
321355
edge_dialect_program = None
322356
reference_outputs = None
323357
representative_inputs = None
358+
export_graph_id = 0
324359

325360
serialized_exported_program_files = set()
326361
serialized_state_dict_files = set()
@@ -347,6 +382,14 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
347382
etrecord_zip.read(f"{entry}_example_inputs"),
348383
)
349384
edge_dialect_program = deserialize(serialized_artifact)
385+
elif entry == ETRecordReservedFileNames.EXPORTED_PROGRAM:
386+
serialized_artifact = SerializedArtifact(
387+
etrecord_zip.read(ETRecordReservedFileNames.EXPORTED_PROGRAM),
388+
etrecord_zip.read(f"{entry}_state_dict"),
389+
etrecord_zip.read(f"{entry}_constants"),
390+
etrecord_zip.read(f"{entry}_example_inputs"),
391+
)
392+
exported_program = deserialize(serialized_artifact)
350393
elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS:
351394
# @lint-ignore PYTHONPICKLEISBAD
352395
reference_outputs = pickle.loads(
@@ -357,6 +400,10 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
357400
representative_inputs = pickle.loads(
358401
etrecord_zip.read(ETRecordReservedFileNames.REPRESENTATIVE_INPUTS)
359402
)
403+
elif entry == ETRecordReservedFileNames.EXPORT_GRAPH_ID:
404+
export_graph_id = json.loads(
405+
etrecord_zip.read(ETRecordReservedFileNames.EXPORT_GRAPH_ID)
406+
)
360407
else:
361408
if entry.endswith("state_dict"):
362409
serialized_state_dict_files.add(entry)
@@ -383,10 +430,12 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
383430
graph_map[serialized_file] = deserialize(serialized_artifact)
384431

385432
return ETRecord(
433+
exported_program=exported_program,
386434
edge_dialect_program=edge_dialect_program,
387435
graph_map=graph_map,
388436
_debug_handle_map=debug_handle_map,
389437
_delegate_map=delegate_map,
390438
_reference_outputs=reference_outputs,
391439
_representative_inputs=representative_inputs,
440+
export_graph_id=export_graph_id,
392441
)

devtools/etrecord/tests/etrecord_test.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,13 @@ def test_etrecord_generation(self):
100100
tmpdirname + "/etrecord.bin",
101101
edge_output,
102102
et_output,
103-
{
103+
extra_recorded_export_modules={
104104
"aten_dialect_output": captured_output,
105105
},
106106
)
107107

108108
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
109+
109110
self.check_graph_closeness(
110111
etrecord.graph_map["aten_dialect_output/forward"],
111112
captured_output.exported_program.graph_module,
@@ -184,7 +185,7 @@ def test_etrecord_invalid_input(self):
184185
tmpdirname + "/etrecord.bin",
185186
edge_output,
186187
et_output,
187-
{"fail_test_case": et_output},
188+
extra_recorded_export_modules={"fail_test_case": et_output},
188189
)
189190

190191
def test_etrecord_reserved_name(self):
@@ -196,5 +197,84 @@ def test_etrecord_reserved_name(self):
196197
tmpdirname + "/etrecord.bin",
197198
edge_output,
198199
et_output,
199-
{reserved_name: captured_output.exported_program.graph_module},
200+
extra_recorded_export_modules={
201+
reserved_name: captured_output.exported_program.graph_module
202+
},
200203
)
204+
205+
def test_etrecord_generation_with_exported_program(self):
206+
"""Test that exported program can be recorded and parsed back correctly."""
207+
captured_output, edge_output, et_output = self.get_test_model()
208+
original_exported_program = captured_output.exported_program
209+
expected_graph_id = id(original_exported_program.graph)
210+
211+
with tempfile.TemporaryDirectory() as tmpdirname:
212+
# Generate ETRecord with exported program
213+
generate_etrecord(
214+
tmpdirname + "/etrecord.bin",
215+
edge_output,
216+
et_output,
217+
exported_program=original_exported_program,
218+
)
219+
220+
# Parse ETRecord back
221+
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
222+
223+
# Validate that the parsed exported program matches the original
224+
self.assertIsNotNone(etrecord.exported_program)
225+
self.check_graph_closeness(
226+
etrecord.exported_program,
227+
original_exported_program.graph_module,
228+
)
229+
230+
# Validate other components are still present
231+
self.check_graph_closeness(
232+
etrecord.edge_dialect_program,
233+
edge_output.exported_program.graph_module,
234+
)
235+
self.assertEqual(
236+
etrecord._debug_handle_map,
237+
json.loads(json.dumps(et_output.debug_handle_map)),
238+
)
239+
240+
# Validate that export_graph_id matches the expected value
241+
self.assertEqual(etrecord.export_graph_id, expected_graph_id)
242+
243+
def test_etrecord_generation_with_exported_program_dict(self):
244+
"""Test that exported program dictionary can be recorded and parsed back correctly."""
245+
captured_output, edge_output, et_output = self.get_test_model()
246+
original_exported_program = captured_output.exported_program
247+
exported_program_dict = {"forward": original_exported_program}
248+
expected_graph_id = id(original_exported_program.graph)
249+
250+
with tempfile.TemporaryDirectory() as tmpdirname:
251+
# Generate ETRecord with exported program dictionary
252+
generate_etrecord(
253+
tmpdirname + "/etrecord.bin",
254+
edge_output,
255+
et_output,
256+
exported_program=exported_program_dict,
257+
)
258+
259+
# Parse ETRecord back
260+
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
261+
262+
# Validate that the parsed exported program matches the original
263+
self.assertIsNotNone(etrecord.exported_program)
264+
self.check_graph_closeness(
265+
etrecord.exported_program,
266+
original_exported_program.graph_module,
267+
)
268+
269+
# Validate other components are still present
270+
self.check_graph_closeness(
271+
etrecord.edge_dialect_program,
272+
edge_output.exported_program.graph_module,
273+
)
274+
self.assertEqual(
275+
etrecord._debug_handle_map,
276+
json.loads(json.dumps(et_output.debug_handle_map)),
277+
)
278+
279+
# Validate that export_graph_id matches the expected value
280+
self.assertEqual(etrecord.export_graph_id, expected_graph_id)

devtools/inspector/tests/inspector_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def test_inspector_get_exported_program(self):
327327
tmpdirname + "/etrecord.bin",
328328
edge_output,
329329
et_output,
330-
{
330+
extra_recorded_export_modules={
331331
"aten_dialect_output": captured_output,
332332
},
333333
)

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_gen_graphs_from_etrecord(self):
5454
tmpdirname + "/etrecord.bin",
5555
edge_output,
5656
et_output,
57-
{
57+
extra_recorded_export_modules={
5858
"aten_dialect_output": captured_output,
5959
},
6060
)

examples/devtools/scripts/gen_sample_etrecord.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def gen_etrecord(model: torch.nn.Module, inputs: Any, output_path=None):
4141
(DEFAULT_OUTPUT_PATH if not output_path else output_path),
4242
edge_dialect_program=edge_program,
4343
executorch_program=et_program,
44-
export_modules={
44+
extra_recorded_export_modules={
4545
"aten_dialect_output": aten_dialect,
4646
},
4747
)

0 commit comments

Comments
 (0)