Skip to content

Commit d0ec019

Browse files
committed
Let directory live
1 parent 5993ded commit d0ec019

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

test/python/test_onnxruntime_genai_api.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -644,9 +644,12 @@ def _export_adapter(adapter, adapter_file_name):
644644
reason="ONNX is not available on ARM64",
645645
)
646646
@pytest.mark.parametrize("extra_inputs", [("num_logits_to_keep", True), ("onnx::Neg_67", True), ("abcde", False)])
647-
def test_preset_extra_inputs(device, phi2_for, extra_inputs):
648-
def _prepare_model(extra_inputs_model_path):
647+
def test_preset_extra_inputs(test_data_path, device, phi2_for, extra_inputs):
648+
def _prepare_model(test_data_path):
649649
phi2_model_path = phi2_for(device)
650+
relative_model_path = "preset_extra_inputs"
651+
extra_inputs_model_path = os.fspath(Path(test_data_path) / relative_model_path)
652+
650653
shutil.copytree(phi2_model_path, extra_inputs_model_path, dirs_exist_ok=True)
651654

652655
# Create the model with the extra inputs
@@ -681,29 +684,28 @@ def _prepare_model(extra_inputs_model_path):
681684
location="model.data",
682685
)
683686

684-
return valid
685-
686-
with tempfile.TemporaryDirectory() as model_path:
687-
valid_model = _prepare_model(model_path)
688-
model = og.Model(model_path)
689-
tokenizer = og.Tokenizer(model)
690-
prompts = [
691-
"This is a test.",
692-
"Rats are awesome pets!",
693-
"The quick brown fox jumps over the lazy dog.",
694-
]
687+
return extra_inputs_model_path, valid
695688

696-
params = og.GeneratorParams(model)
697-
params.set_search_options(max_length=20, batch_size=len(prompts))
689+
model_path, valid_model = _prepare_model(test_data_path)
690+
model = og.Model(model_path)
691+
tokenizer = og.Tokenizer(model)
692+
prompts = [
693+
"This is a test.",
694+
"Rats are awesome pets!",
695+
"The quick brown fox jumps over the lazy dog.",
696+
]
698697

699-
generator = og.Generator(model, params)
700-
if not valid_model:
701-
with pytest.raises(og.OrtException) as exc_info:
702-
generator.append_tokens(tokenizer.encode_batch(prompts))
698+
params = og.GeneratorParams(model)
699+
params.set_search_options(max_length=20, batch_size=len(prompts))
703700

704-
assert f"Missing Input: {extra_inputs[0]}" in str(exc_info.value)
705-
else:
701+
generator = og.Generator(model, params)
702+
if not valid_model:
703+
with pytest.raises(og.OrtException) as exc_info:
706704
generator.append_tokens(tokenizer.encode_batch(prompts))
707705

708-
while not generator.is_done():
709-
generator.generate_next_token()
706+
assert f"Missing Input: {extra_inputs[0]}" in str(exc_info.value)
707+
else:
708+
generator.append_tokens(tokenizer.encode_batch(prompts))
709+
710+
while not generator.is_done():
711+
generator.generate_next_token()

0 commit comments

Comments
 (0)