@@ -181,7 +181,7 @@ def verify_with_input_ids(
181
181
original_model : ModelWrapper ,
182
182
reauthored_model : ReauthoredModelWrapper ,
183
183
input_ids : List [int ],
184
- kv_cache_max_len : int = 1024 ,
184
+ kv_cache_max_len : int = 128 ,
185
185
rtol : float = 1e-05 ,
186
186
atol : float = 1e-05 ,
187
187
):
@@ -273,6 +273,8 @@ def verify_reauthored_model(
273
273
rtol : float = 1e-05 ,
274
274
atol : float = 1e-05 ,
275
275
continue_on_failure : bool = False ,
276
+ verify_inputs : bool = True ,
277
+ verify_prompts : bool = True ,
276
278
) -> bool :
277
279
"""Verifies the reauthored model against the original model.
278
280
@@ -301,33 +303,37 @@ def verify_reauthored_model(
301
303
"""
302
304
failure_count = 0
303
305
304
- for input_ids in forward_input_ids :
305
- logging .info ("Verifying the reauthored model with input IDs: %s" , input_ids )
306
- try :
307
- verify_with_input_ids (
308
- original_model , reauthored_model , input_ids , rtol = rtol , atol = atol
306
+ if verify_inputs :
307
+ for input_ids in forward_input_ids :
308
+ logging .info (
309
+ "Verifying the reauthored model with input IDs: %s" , input_ids
309
310
)
310
- except AssertionError as e :
311
- logging .error ("*** FAILED *** verify with input IDs: %s" , input_ids )
312
- failure_count += 1
313
- if not continue_on_failure :
314
- return False
315
- else :
316
- logging .info ("*** PASSED *** verify with input IDs: %s" , input_ids )
317
-
318
- for prompts in generate_prompts :
319
- logging .info ("Verifying the reauthored model with prompts: %s" , prompts )
320
- try :
321
- verify_model_with_prompts (
322
- original_model , reauthored_model , tokenizer , prompts , max_new_tokens
323
- )
324
- except AssertionError as e :
325
- logging .error ("*** FAILED *** verify with prompts: %s" , prompts )
326
- failure_count += 1
327
- if not continue_on_failure :
328
- return False
329
- else :
330
- logging .info ("*** PASSED *** verify with prompts: %s" , prompts )
311
+ try :
312
+ verify_with_input_ids (
313
+ original_model , reauthored_model , input_ids , rtol = rtol , atol = atol
314
+ )
315
+ except AssertionError as e :
316
+ logging .error ("*** FAILED *** verify with input IDs: %s" , input_ids )
317
+ failure_count += 1
318
+ if not continue_on_failure :
319
+ return False
320
+ else :
321
+ logging .info ("*** PASSED *** verify with input IDs: %s" , input_ids )
322
+
323
+ if verify_prompts :
324
+ for prompts in generate_prompts :
325
+ logging .info ("Verifying the reauthored model with prompts: %s" , prompts )
326
+ try :
327
+ verify_model_with_prompts (
328
+ original_model , reauthored_model , tokenizer , prompts , max_new_tokens
329
+ )
330
+ except AssertionError as e :
331
+ logging .error ("*** FAILED *** verify with prompts: %s" , prompts )
332
+ failure_count += 1
333
+ if not continue_on_failure :
334
+ return False
335
+ else :
336
+ logging .info ("*** PASSED *** verify with prompts: %s" , prompts )
331
337
332
338
if failure_count == 0 :
333
339
logging .info ("*** PASSED *** verify_reauthored_model" )
0 commit comments