@@ -644,9 +644,12 @@ def _export_adapter(adapter, adapter_file_name):
644644 reason = "ONNX is not available on ARM64" ,
645645)
646646@pytest .mark .parametrize ("extra_inputs" , [("num_logits_to_keep" , True ), ("onnx::Neg_67" , True ), ("abcde" , False )])
647- def test_preset_extra_inputs (device , phi2_for , extra_inputs ):
648- def _prepare_model (extra_inputs_model_path ):
647+ def test_preset_extra_inputs (test_data_path , device , phi2_for , extra_inputs ):
648+ def _prepare_model (test_data_path ):
649649 phi2_model_path = phi2_for (device )
650+ relative_model_path = "preset_extra_inputs"
651+ extra_inputs_model_path = os .fspath (Path (test_data_path ) / relative_model_path )
652+
650653 shutil .copytree (phi2_model_path , extra_inputs_model_path , dirs_exist_ok = True )
651654
652655 # Create the model with the extra inputs
@@ -681,29 +684,28 @@ def _prepare_model(extra_inputs_model_path):
681684 location = "model.data" ,
682685 )
683686
684- return valid
685-
686- with tempfile .TemporaryDirectory () as model_path :
687- valid_model = _prepare_model (model_path )
688- model = og .Model (model_path )
689- tokenizer = og .Tokenizer (model )
690- prompts = [
691- "This is a test." ,
692- "Rats are awesome pets!" ,
693- "The quick brown fox jumps over the lazy dog." ,
694- ]
687+ return extra_inputs_model_path , valid
695688
696- params = og .GeneratorParams (model )
697- params .set_search_options (max_length = 20 , batch_size = len (prompts ))
689+ model_path , valid_model = _prepare_model (test_data_path )
690+ model = og .Model (model_path )
691+ tokenizer = og .Tokenizer (model )
692+ prompts = [
693+ "This is a test." ,
694+ "Rats are awesome pets!" ,
695+ "The quick brown fox jumps over the lazy dog." ,
696+ ]
698697
699- generator = og .Generator (model , params )
700- if not valid_model :
701- with pytest .raises (og .OrtException ) as exc_info :
702- generator .append_tokens (tokenizer .encode_batch (prompts ))
698+ params = og .GeneratorParams (model )
699+ params .set_search_options (max_length = 20 , batch_size = len (prompts ))
703700
704- assert f"Missing Input: { extra_inputs [0 ]} " in str (exc_info .value )
705- else :
701+ generator = og .Generator (model , params )
702+ if not valid_model :
703+ with pytest .raises (og .OrtException ) as exc_info :
706704 generator .append_tokens (tokenizer .encode_batch (prompts ))
707705
708- while not generator .is_done ():
709- generator .generate_next_token ()
706+ assert f"Missing Input: { extra_inputs [0 ]} " in str (exc_info .value )
707+ else :
708+ generator .append_tokens (tokenizer .encode_batch (prompts ))
709+
710+ while not generator .is_done ():
711+ generator .generate_next_token ()
0 commit comments