Skip to content

Commit f8a3fd8

Browse files
authored
Save the representative intputs into the ETRecord object
Differential Revision: D75637400 Pull Request resolved: #11244
1 parent b308544 commit f8a3fd8

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from executorch.exir.serde.export_serialize import SerializedArtifact
3030
from executorch.exir.serde.serialize import deserialize, serialize
3131

32+
ProgramInput = List[Value]
3233
ProgramOutput = List[Value]
3334

3435
try:
@@ -49,6 +50,7 @@ class ETRecordReservedFileNames(StrEnum):
4950
DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
5051
DELEGATE_MAP_NAME = "delegate_map"
5152
REFERENCE_OUTPUTS = "reference_outputs"
53+
REPRESENTATIVE_INPUTS = "representative_inputs"
5254

5355

5456
@dataclass
@@ -60,6 +62,7 @@ class ETRecord:
6062
Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
6163
] = None
6264
_reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None
65+
_representative_inputs: Optional[List[ProgramOutput]] = None
6366

6467

6568
def _handle_exported_program(
@@ -157,6 +160,24 @@ def _get_reference_outputs(
157160
return reference_outputs
158161

159162

163+
def _get_representative_inputs(
164+
bundled_program: BundledProgram,
165+
) -> List[ProgramInput]:
166+
"""
167+
Extracts out the inputs from the bundled program, keyed by the method names.
168+
"""
169+
for method_test_suite in bundled_program.method_test_suites:
170+
if method_test_suite.method_name == "forward":
171+
if not method_test_suite.test_cases:
172+
raise ValueError(
173+
"The 'forward' method is defined, but no corresponding input test cases are provided."
174+
)
175+
# Get first example input from the forward method
176+
test_case = method_test_suite.test_cases[0]
177+
return test_case.inputs
178+
raise ValueError("No 'forward' method found in the bundled program.")
179+
180+
160181
def generate_etrecord(
161182
et_record: Union[str, os.PathLike, BinaryIO, IO[bytes]],
162183
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
@@ -244,6 +265,13 @@ def generate_etrecord(
244265
# @lint-ignore PYTHONPICKLEISBAD
245266
pickle.dumps(reference_outputs),
246267
)
268+
269+
representative_inputs = _get_representative_inputs(executorch_program)
270+
etrecord_zip.writestr(
271+
ETRecordReservedFileNames.REPRESENTATIVE_INPUTS,
272+
# @lint-ignore PYTHONPICKLEISBAD
273+
pickle.dumps(representative_inputs),
274+
)
247275
executorch_program = executorch_program.executorch_program
248276

249277
etrecord_zip.writestr(
@@ -290,6 +318,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
290318
delegate_map = None
291319
edge_dialect_program = None
292320
reference_outputs = None
321+
representative_inputs = None
293322

294323
serialized_exported_program_files = set()
295324
serialized_state_dict_files = set()
@@ -321,6 +350,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
321350
reference_outputs = pickle.loads(
322351
etrecord_zip.read(ETRecordReservedFileNames.REFERENCE_OUTPUTS)
323352
)
353+
elif entry == ETRecordReservedFileNames.REPRESENTATIVE_INPUTS:
354+
# @lint-ignore PYTHONPICKLEISBAD
355+
representative_inputs = pickle.loads(
356+
etrecord_zip.read(ETRecordReservedFileNames.REPRESENTATIVE_INPUTS)
357+
)
324358
else:
325359
if entry.endswith("state_dict"):
326360
serialized_state_dict_files.add(entry)
@@ -352,4 +386,5 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
352386
_debug_handle_map=debug_handle_map,
353387
_delegate_map=delegate_map,
354388
_reference_outputs=reference_outputs,
389+
_representative_inputs=representative_inputs,
355390
)

devtools/etrecord/tests/etrecord_test.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from executorch.devtools.etrecord import generate_etrecord, parse_etrecord
2020
from executorch.devtools.etrecord._etrecord import (
2121
_get_reference_outputs,
22+
_get_representative_inputs,
2223
ETRecordReservedFileNames,
2324
)
2425
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
@@ -135,15 +136,25 @@ def test_etrecord_generation_with_bundled_program(self):
135136
)
136137
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
137138

138-
expected = etrecord._reference_outputs
139-
actual = _get_reference_outputs(bundled_program)
139+
expected_inputs = etrecord._representative_inputs
140+
actual_inputs = _get_representative_inputs(bundled_program)
140141
# assertEqual() gives "RuntimeError: Boolean value of Tensor with more than one value is ambiguous" when comparing tensors,
141142
# so we use torch.equal() to compare the tensors one by one.
143+
for expected, actual in zip(expected_inputs, actual_inputs):
144+
self.assertTrue(torch.equal(expected[0], actual[0]))
145+
self.assertTrue(torch.equal(expected[1], actual[1]))
146+
147+
expected_outputs = etrecord._reference_outputs
148+
actual_outputs = _get_reference_outputs(bundled_program)
142149
self.assertTrue(
143-
torch.equal(expected["forward"][0][0], actual["forward"][0][0])
150+
torch.equal(
151+
expected_outputs["forward"][0][0], actual_outputs["forward"][0][0]
152+
)
144153
)
145154
self.assertTrue(
146-
torch.equal(expected["forward"][1][0], actual["forward"][1][0])
155+
torch.equal(
156+
expected_outputs["forward"][1][0], actual_outputs["forward"][1][0]
157+
)
147158
)
148159

149160
def test_etrecord_generation_with_manager(self):

0 commit comments

Comments
 (0)