diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index ffb81a8e41a..014148f2a13 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -45,6 +45,8 @@ class StrEnum(str, Enum): class ETRecordReservedFileNames(StrEnum): ETRECORD_IDENTIFIER = "ETRECORD_V0" + EXPORTED_PROGRAM = "exported_program" + EXPORT_GRAPH_ID = "export_graph_id" EDGE_DIALECT_EXPORTED_PROGRAM = "edge_dialect_exported_program" ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module" DEBUG_HANDLE_MAP_NAME = "debug_handle_map" @@ -55,6 +57,8 @@ class ETRecordReservedFileNames(StrEnum): @dataclass class ETRecord: + exported_program: Optional[ExportedProgram] = None + export_graph_id: Optional[int] = None edge_dialect_program: Optional[ExportedProgram] = None graph_map: Optional[Dict[str, ExportedProgram]] = None _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None @@ -71,17 +75,20 @@ def _handle_exported_program( assert isinstance(ep, ExportedProgram) serialized_artifact = serialize(ep) assert isinstance(serialized_artifact.exported_program, bytes) + + method_name = f"/{method_name}" if method_name != "" else "" + etrecord_zip.writestr( - f"{module_name}/{method_name}", serialized_artifact.exported_program + f"{module_name}{method_name}", serialized_artifact.exported_program ) etrecord_zip.writestr( - f"{module_name}/{method_name}_state_dict", serialized_artifact.state_dict + f"{module_name}{method_name}_state_dict", serialized_artifact.state_dict ) etrecord_zip.writestr( - f"{module_name}/{method_name}_constants", serialized_artifact.constants + f"{module_name}{method_name}_constants", serialized_artifact.constants ) etrecord_zip.writestr( - f"{module_name}/{method_name}_example_inputs", + f"{module_name}{method_name}_example_inputs", serialized_artifact.example_inputs, ) @@ -188,7 +195,10 @@ def generate_etrecord( ExecutorchProgramManager, BundledProgram, ], - export_modules: Optional[ + exported_program: Optional[ + Union[ExportedProgram, Dict[str, ExportedProgram]] + ] = None, + extra_recorded_export_modules: Optional[ Dict[ str, Union[ @@ -202,7 +212,7 @@ def generate_etrecord( """ Generates an `ETRecord` from the given objects, serializes it and saves it to the given path. The objects that will be serialized to an `ETRecord` are all the graph modules present - in the `export_modules` dict, the graph module present in the edge dialect program object, + in the `extra_recorded_export_modules` dict, the graph module present in the edge dialect program object, and also the graph module present in the ExecuTorch program object, which is the closest graph module representation of what is eventually run on the device. In addition to all the graph modules, we also serialize the program buffer, which the users @@ -213,7 +223,8 @@ def generate_etrecord( et_record: Path to where the `ETRecord` file will be saved to. edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge() executorch_program: The ExecuTorch program for this model returned by the call to `to_executorch()` or the `BundledProgram` of this model - export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the + exported_program: Optional graph module for this model returned by the call to `torch.export` from nn.Module. + extra_recorded_export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the value being the corresponding exported module. The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`. @@ -229,15 +240,32 @@ def generate_etrecord( # is an etrecord when it's used later in the Developer Tools. etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "") - if export_modules is not None: - for module_name, export_module in export_modules.items(): + # Calculate export_graph_id before modifying exported_program + export_graph_id = 0 + + if exported_program is not None: + # If multiple exported programs are provided, only save forward method + if isinstance(exported_program, dict) and "forward" in exported_program: + exported_program = exported_program["forward"] + + if isinstance(exported_program, ExportedProgram): + export_graph_id = id(exported_program.graph) + _handle_exported_program( + etrecord_zip, + ETRecordReservedFileNames.EXPORTED_PROGRAM, + "", + exported_program, + ) + + if extra_recorded_export_modules is not None: + for module_name, export_module in extra_recorded_export_modules.items(): contains_reserved_name = any( reserved_name in module_name for reserved_name in ETRecordReservedFileNames ) if contains_reserved_name: raise RuntimeError( - f"The name {module_name} provided in the export_modules dict is a reserved name in the ETRecord namespace." + f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace." ) _handle_export_module(etrecord_zip, export_module, module_name) @@ -286,6 +314,11 @@ def generate_etrecord( json.dumps(executorch_program.delegate_map), ) + etrecord_zip.writestr( + ETRecordReservedFileNames.EXPORT_GRAPH_ID, + json.dumps(export_graph_id), + ) + def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 """ @@ -318,9 +351,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 graph_map: Dict[str, ExportedProgram] = {} debug_handle_map = None delegate_map = None + exported_program = None edge_dialect_program = None reference_outputs = None representative_inputs = None + export_graph_id = 0 serialized_exported_program_files = set() serialized_state_dict_files = set() @@ -347,6 +382,14 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 etrecord_zip.read(f"{entry}_example_inputs"), ) edge_dialect_program = deserialize(serialized_artifact) + elif entry == ETRecordReservedFileNames.EXPORTED_PROGRAM: + serialized_artifact = SerializedArtifact( + etrecord_zip.read(ETRecordReservedFileNames.EXPORTED_PROGRAM), + etrecord_zip.read(f"{entry}_state_dict"), + etrecord_zip.read(f"{entry}_constants"), + etrecord_zip.read(f"{entry}_example_inputs"), + ) + exported_program = deserialize(serialized_artifact) elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS: # @lint-ignore PYTHONPICKLEISBAD reference_outputs = pickle.loads( @@ -357,6 +400,10 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 representative_inputs = pickle.loads( etrecord_zip.read(ETRecordReservedFileNames.REPRESENTATIVE_INPUTS) ) + elif entry == ETRecordReservedFileNames.EXPORT_GRAPH_ID: + export_graph_id = json.loads( + etrecord_zip.read(ETRecordReservedFileNames.EXPORT_GRAPH_ID) + ) else: if entry.endswith("state_dict"): serialized_state_dict_files.add(entry) @@ -383,10 +430,12 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 graph_map[serialized_file] = deserialize(serialized_artifact) return ETRecord( + exported_program=exported_program, edge_dialect_program=edge_dialect_program, graph_map=graph_map, _debug_handle_map=debug_handle_map, _delegate_map=delegate_map, _reference_outputs=reference_outputs, _representative_inputs=representative_inputs, + export_graph_id=export_graph_id, ) diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index dd1d40e0292..85d19c5e952 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -100,12 +100,13 @@ def test_etrecord_generation(self): tmpdirname + "/etrecord.bin", edge_output, et_output, - { + extra_recorded_export_modules={ "aten_dialect_output": captured_output, }, ) etrecord = parse_etrecord(tmpdirname + "/etrecord.bin") + self.check_graph_closeness( etrecord.graph_map["aten_dialect_output/forward"], captured_output.exported_program.graph_module, @@ -184,7 +185,7 @@ def test_etrecord_invalid_input(self): tmpdirname + "/etrecord.bin", edge_output, et_output, - {"fail_test_case": et_output}, + extra_recorded_export_modules={"fail_test_case": et_output}, ) def test_etrecord_reserved_name(self): @@ -196,5 +197,84 @@ def test_etrecord_reserved_name(self): tmpdirname + "/etrecord.bin", edge_output, et_output, - {reserved_name: captured_output.exported_program.graph_module}, + extra_recorded_export_modules={ + reserved_name: captured_output.exported_program.graph_module + }, ) + + def test_etrecord_generation_with_exported_program(self): + """Test that exported program can be recorded and parsed back correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + original_exported_program = captured_output.exported_program + expected_graph_id = id(original_exported_program.graph) + + with tempfile.TemporaryDirectory() as tmpdirname: + # Generate ETRecord with exported program + generate_etrecord( + tmpdirname + "/etrecord.bin", + edge_output, + et_output, + exported_program=original_exported_program, + ) + + # Parse ETRecord back + etrecord = parse_etrecord(tmpdirname + "/etrecord.bin") + + # Validate that the parsed exported program matches the original + self.assertIsNotNone(etrecord.exported_program) + self.check_graph_closeness( + etrecord.exported_program, + original_exported_program.graph_module, + ) + + # Validate other components are still present + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + + # Validate that export_graph_id matches the expected value + self.assertEqual(etrecord.export_graph_id, expected_graph_id) + + def test_etrecord_generation_with_exported_program_dict(self): + """Test that exported program dictionary can be recorded and parsed back correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + original_exported_program = captured_output.exported_program + exported_program_dict = {"forward": original_exported_program} + expected_graph_id = id(original_exported_program.graph) + + with tempfile.TemporaryDirectory() as tmpdirname: + # Generate ETRecord with exported program dictionary + generate_etrecord( + tmpdirname + "/etrecord.bin", + edge_output, + et_output, + exported_program=exported_program_dict, + ) + + # Parse ETRecord back + etrecord = parse_etrecord(tmpdirname + "/etrecord.bin") + + # Validate that the parsed exported program matches the original + self.assertIsNotNone(etrecord.exported_program) + self.check_graph_closeness( + etrecord.exported_program, + original_exported_program.graph_module, + ) + + # Validate other components are still present + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + + # Validate that export_graph_id matches the expected value + self.assertEqual(etrecord.export_graph_id, expected_graph_id) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 7246dd3bac0..1b4051cb813 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -327,7 +327,7 @@ def test_inspector_get_exported_program(self): tmpdirname + "/etrecord.bin", edge_output, et_output, - { + extra_recorded_export_modules={ "aten_dialect_output": captured_output, }, ) diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 9c0b5fc7fc5..47113910e98 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -54,7 +54,7 @@ def test_gen_graphs_from_etrecord(self): tmpdirname + "/etrecord.bin", edge_output, et_output, - { + extra_recorded_export_modules={ "aten_dialect_output": captured_output, }, ) diff --git a/examples/devtools/scripts/gen_sample_etrecord.py b/examples/devtools/scripts/gen_sample_etrecord.py index a6b3d487251..e5b46cdede5 100644 --- a/examples/devtools/scripts/gen_sample_etrecord.py +++ b/examples/devtools/scripts/gen_sample_etrecord.py @@ -41,7 +41,7 @@ def gen_etrecord(model: torch.nn.Module, inputs: Any, output_path=None): (DEFAULT_OUTPUT_PATH if not output_path else output_path), edge_dialect_program=edge_program, executorch_program=et_program, - export_modules={ + extra_recorded_export_modules={ "aten_dialect_output": aten_dialect, }, )