Skip to content

Commit 8f4d88c

Browse files
talumbaucopybara-github
authored andcommitted
Gemma-3 4B verification
PiperOrigin-RevId: 747671354
1 parent 9625f16 commit 8f4d88c

File tree

2 files changed

+34
-28
lines changed

2 files changed

+34
-28
lines changed

ai_edge_torch/generative/utilities/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def define_conversion_flags(model_name: str):
5757
)
5858
flags.DEFINE_string(
5959
'output_name_prefix',
60-
'qwen',
60+
f'{model_name}',
6161
'The prefix of the output tflite model name.',
6262
)
6363
flags.DEFINE_multi_integer(

ai_edge_torch/generative/utilities/verifier.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def verify_with_input_ids(
181181
original_model: ModelWrapper,
182182
reauthored_model: ReauthoredModelWrapper,
183183
input_ids: List[int],
184-
kv_cache_max_len: int = 1024,
184+
kv_cache_max_len: int = 128,
185185
rtol: float = 1e-05,
186186
atol: float = 1e-05,
187187
):
@@ -273,6 +273,8 @@ def verify_reauthored_model(
273273
rtol: float = 1e-05,
274274
atol: float = 1e-05,
275275
continue_on_failure: bool = False,
276+
verify_inputs: bool = True,
277+
verify_prompts: bool = True,
276278
) -> bool:
277279
"""Verifies the reauthored model against the original model.
278280
@@ -301,33 +303,37 @@ def verify_reauthored_model(
301303
"""
302304
failure_count = 0
303305

304-
for input_ids in forward_input_ids:
305-
logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
306-
try:
307-
verify_with_input_ids(
308-
original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
306+
if verify_inputs:
307+
for input_ids in forward_input_ids:
308+
logging.info(
309+
"Verifying the reauthored model with input IDs: %s", input_ids
309310
)
310-
except AssertionError as e:
311-
logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
312-
failure_count += 1
313-
if not continue_on_failure:
314-
return False
315-
else:
316-
logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
317-
318-
for prompts in generate_prompts:
319-
logging.info("Verifying the reauthored model with prompts: %s", prompts)
320-
try:
321-
verify_model_with_prompts(
322-
original_model, reauthored_model, tokenizer, prompts, max_new_tokens
323-
)
324-
except AssertionError as e:
325-
logging.error("*** FAILED *** verify with prompts: %s", prompts)
326-
failure_count += 1
327-
if not continue_on_failure:
328-
return False
329-
else:
330-
logging.info("*** PASSED *** verify with prompts: %s", prompts)
311+
try:
312+
verify_with_input_ids(
313+
original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
314+
)
315+
except AssertionError as e:
316+
logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
317+
failure_count += 1
318+
if not continue_on_failure:
319+
return False
320+
else:
321+
logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
322+
323+
if verify_prompts:
324+
for prompts in generate_prompts:
325+
logging.info("Verifying the reauthored model with prompts: %s", prompts)
326+
try:
327+
verify_model_with_prompts(
328+
original_model, reauthored_model, tokenizer, prompts, max_new_tokens
329+
)
330+
except AssertionError as e:
331+
logging.error("*** FAILED *** verify with prompts: %s", prompts)
332+
failure_count += 1
333+
if not continue_on_failure:
334+
return False
335+
else:
336+
logging.info("*** PASSED *** verify with prompts: %s", prompts)
331337

332338
if failure_count == 0:
333339
logging.info("*** PASSED *** verify_reauthored_model")

0 commit comments

Comments
 (0)