29
29
from executorch .exir .serde .export_serialize import SerializedArtifact
30
30
from executorch .exir .serde .serialize import deserialize , serialize
31
31
32
+ ProgramInput = List [Value ]
32
33
ProgramOutput = List [Value ]
33
34
34
35
try :
@@ -49,6 +50,7 @@ class ETRecordReservedFileNames(StrEnum):
49
50
DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
50
51
DELEGATE_MAP_NAME = "delegate_map"
51
52
REFERENCE_OUTPUTS = "reference_outputs"
53
+ REPRESENTATIVE_INPUTS = "representative_inputs"
52
54
53
55
54
56
@dataclass
@@ -60,6 +62,7 @@ class ETRecord:
60
62
Dict [str , Dict [int , Dict [str , Union [str , _DelegateDebugIdentifierMap ]]]]
61
63
] = None
62
64
_reference_outputs : Optional [Dict [str , List [ProgramOutput ]]] = None
65
+ _representative_inputs : Optional [List [ProgramOutput ]] = None
63
66
64
67
65
68
def _handle_exported_program (
@@ -157,6 +160,24 @@ def _get_reference_outputs(
157
160
return reference_outputs
158
161
159
162
163
+ def _get_representative_inputs (
164
+ bundled_program : BundledProgram ,
165
+ ) -> List [ProgramInput ]:
166
+ """
167
+ Extracts out the inputs from the bundled program, keyed by the method names.
168
+ """
169
+ for method_test_suite in bundled_program .method_test_suites :
170
+ if method_test_suite .method_name == "forward" :
171
+ if not method_test_suite .test_cases :
172
+ raise ValueError (
173
+ "The 'forward' method is defined, but no corresponding input test cases are provided."
174
+ )
175
+ # Get first example input from the forward method
176
+ test_case = method_test_suite .test_cases [0 ]
177
+ return test_case .inputs
178
+ raise ValueError ("No 'forward' method found in the bundled program." )
179
+
180
+
160
181
def generate_etrecord (
161
182
et_record : Union [str , os .PathLike , BinaryIO , IO [bytes ]],
162
183
edge_dialect_program : Union [EdgeProgramManager , ExirExportedProgram ],
@@ -244,6 +265,13 @@ def generate_etrecord(
244
265
# @lint-ignore PYTHONPICKLEISBAD
245
266
pickle .dumps (reference_outputs ),
246
267
)
268
+
269
+ representative_inputs = _get_representative_inputs (executorch_program )
270
+ etrecord_zip .writestr (
271
+ ETRecordReservedFileNames .REPRESENTATIVE_INPUTS ,
272
+ # @lint-ignore PYTHONPICKLEISBAD
273
+ pickle .dumps (representative_inputs ),
274
+ )
247
275
executorch_program = executorch_program .executorch_program
248
276
249
277
etrecord_zip .writestr (
@@ -290,6 +318,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
290
318
delegate_map = None
291
319
edge_dialect_program = None
292
320
reference_outputs = None
321
+ representative_inputs = None
293
322
294
323
serialized_exported_program_files = set ()
295
324
serialized_state_dict_files = set ()
@@ -321,6 +350,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
321
350
reference_outputs = pickle .loads (
322
351
etrecord_zip .read (ETRecordReservedFileNames .REFERENCE_OUTPUTS )
323
352
)
353
+ elif entry == ETRecordReservedFileNames .REPRESENTATIVE_INPUTS :
354
+ # @lint-ignore PYTHONPICKLEISBAD
355
+ representative_inputs = pickle .loads (
356
+ etrecord_zip .read (ETRecordReservedFileNames .REPRESENTATIVE_INPUTS )
357
+ )
324
358
else :
325
359
if entry .endswith ("state_dict" ):
326
360
serialized_state_dict_files .add (entry )
@@ -352,4 +386,5 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
352
386
_debug_handle_map = debug_handle_map ,
353
387
_delegate_map = delegate_map ,
354
388
_reference_outputs = reference_outputs ,
389
+ _representative_inputs = representative_inputs ,
355
390
)
0 commit comments