Merged

26 commits
60ff5f0
Add initial script for model validation tool
ayissi-msft Sep 23, 2024
4d48841
Removing calculate_perplexity logic
ayissi-msft Sep 24, 2024
2df8e01
Update validation script and config files based on feedback
ayissi-msft Sep 25, 2024
6481452
created a json object, loop, and updated the validate_model method
ayissi-msft Sep 27, 2024
15eb7d8
removed json object, added table to be printed
ayissi-msft Sep 30, 2024
e1cb1b7
fix return statement for validate_model
ayissi-msft Sep 30, 2024
6063694
updated validation_tool and add exception messages
ayissi-msft Sep 30, 2024
b2b27fb
fixing the config file
ayissi-msft Oct 2, 2024
d6e9aed
added the precision and executive provider to the config
ayissi-msft Oct 3, 2024
8f291ba
adding chat template
ayissi-msft Oct 4, 2024
9b94bc7
Add the README.md
ayissi-msft Oct 7, 2024
d25d69f
reformatting the config file
ayissi-msft Oct 7, 2024
2c5bad9
updating the chat templates
ayissi-msft Oct 8, 2024
e8756ed
updated README
ayissi-msft Oct 8, 2024
3f707c4
updated chat templates + README.md
ayissi-msft Oct 8, 2024
c4f503e
updated README.md and requirements.txt file
ayissi-msft Oct 9, 2024
876d22f
updated requirements.txt, config file, and validation tool script
ayissi-msft Oct 9, 2024
693bf49
added default chat templates
ayissi-msft Oct 10, 2024
3ff50bb
updated default chat templates
ayissi-msft Oct 15, 2024
71556ed
removing values
ayissi-msft Oct 16, 2024
be6ae2c
fixing ep typo
ayissi-msft Oct 16, 2024
b7e4041
update README.md
ayissi-msft Oct 16, 2024
e47b9a4
added additional resources for chat templates
ayissi-msft Oct 16, 2024
6279a1c
removing unnecessary comments
ayissi-msft Oct 16, 2024
856f5eb
adding the correct chat template for llama 2
ayissi-msft Oct 16, 2024
7bb1996
added all the supported model family's chat_templates
ayissi-msft Oct 16, 2024
49 changes: 49 additions & 0 deletions tools/python/model_validation/README.md
@@ -0,0 +1,49 @@
# ONNX Runtime GenAI Model Validation Tutorial

## Background
ONNX Runtime GenAI (Gen-AI) is an API framework for running generative models with ONNX Runtime. As the variety of models grows, there is a growing need for a tool chain that can effectively assess the compatibility between Gen-AI and different model variants.

## Setup and Requirements
Clone this repository and navigate to the `tools/python/model_validation` folder.

```bash
git clone https://github.com/microsoft/onnxruntime-genai.git
cd onnxruntime-genai/tools/python/model_validation
pip install -r requirements.txt
```

Ensure you are logged in to Hugging Face.

Read more about the Hugging Face CLI [here](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli).
```bash
huggingface-cli login
```

Within the model_validation directory, you'll find the `validation_tool.py` script, the `validation_config.json` configuration file, and this `README.md`.

### Current Supported Model Architectures
* Gemma
* Llama
* Mistral
* Phi (language + vision)
* Qwen

## Usage
### Steps to Configure
1. Enter the name of the Hugging Face model you want to validate in the `validation_config.json` file. You can browse models on the [Hugging Face Hub](https://huggingface.co).

* Also, add the `chat_template` associated with your model. It can be found in the model's `tokenizer_config.json` file on the Hugging Face website. Make sure to replace `message['content']` with `{input}` (see the sketch after these configuration steps).
* Discover more about chat templates [here](https://huggingface.co/docs/transformers/main/chat_templating)


2. Specify the output folder path you prefer, along with the precision and execution provider to use.

After the model has been built, it will be located in the `path_to_output_folder/{model_name}` directory (slashes in the Hugging Face model name are replaced with underscores). This directory contains both the ONNX model data and the tokenizer.
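For reference, here is a simplified sketch (not the tool itself) of how `validation_tool.py` applies these settings; the model name, precision, execution provider, and paths below are placeholder examples:

```python
from onnxruntime_genai.models.builder import create_model

model_name = "meta-llama/Llama-2-7b-chat-hf"                # "name" from validation_config.json
chat_template = "<s>[INST]<<SYS>>\n{input}<</SYS>>[INST]"   # "chat_template" from validation_config.json

# The ONNX model and tokenizer are built into <output_directory>/<model name with "/" replaced by "_">
output_path = "path_to_output_folder/" + model_name.replace("/", "_")
create_model(model_name, '', output_path, "fp32", "cpu", "cache_dir")  # precision, execution provider, cache dir

# Each entry in "inputs" is substituted for {input} in the chat template before tokenization
prompt = chat_template.format(input="Why is the sky blue?")
```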

### Run the Model Validation Script
```bash
python validation_tool.py -j validation_config.json
```

### Output
Once the tool has run successfully, it generates a file named `validation_summary.csv`. This file contains the model name, whether the validation tool completed, and details of any exceptions or failures encountered during validation.
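For example, you can load the summary with pandas (already listed in `requirements.txt`) to list the models that failed; a minimal sketch:

```python
import pandas as pd

# Load the per-model summary written by validation_tool.py
summary = pd.read_csv("validation_summary.csv", index_col=0)

# Models that did not complete validation, with the recorded failure
failed = summary[~summary["Validation Completed"]]
print(failed[["Model Name", "Exceptions / Failures"]])
```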
13 changes: 13 additions & 0 deletions tools/python/model_validation/requirements.txt
@@ -0,0 +1,13 @@
coloredlogs
flatbuffers
numpy<2
packaging
protobuf>=5.28.2
sympy
pytest
onnx
transformers
huggingface_hub[cli]
onnxruntime-genai
sentencepiece
pandas
25 changes: 25 additions & 0 deletions tools/python/model_validation/validation_config.json
@@ -0,0 +1,25 @@
{
    "models": [
        {
            "name": "meta-llama/Llama-2-7b-chat-hf",
            "chat_template": "<s>[INST]<<SYS>>\n{input}<</SYS>>[INST]"
        }
    ],
    "inputs": [
        "Provide a detailed analysis as to why the University of Southern California is better than the University of California, Los Angeles."
    ],
    "output_directory": "",
    "cache_directory": "",
    "precision": "",
    "execution_provider": "",
    "verbose": false,
    "search_options": {
        "max_length": 512,
        "min_length": 0,
        "do_sample": false,
        "top_p": 0.0,
        "top_k": 1,
        "temperature": 1.0,
        "repetition_penalty": 1.0
    }
}
122 changes: 122 additions & 0 deletions tools/python/model_validation/validation_tool.py
@@ -0,0 +1,122 @@
import onnxruntime_genai as og
import argparse
from onnxruntime_genai.models.builder import create_model
import json
import os
import pandas as pd

def create_table(output):
    # One row per model: name, whether validation completed, and any exception raised
    df = pd.DataFrame(output, columns=['Model Name', 'Validation Completed', 'Exceptions / Failures'])
    return df

def validate_model(args, model_dict, model_dir):
    if args["verbose"]: print("Loading model...")

    model = og.Model(f'{model_dir}')

    if args["verbose"]: print("Model loaded")
    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()

    if args["verbose"]: print("Tokenizer created")
    if args["verbose"]: print()

    chat_template = model_dict["chat_template"]

    search_options = args["search_options"]

    for text in args["inputs"]:

        complete_text = ''

        # Fill the {input} placeholder in the model's chat template with the prompt text
        prompt = f'{chat_template.format(input=text)}'

        input_tokens = tokenizer.encode(prompt)

        params = og.GeneratorParams(model)
        params.set_search_options(**search_options)
        params.input_ids = input_tokens

        generator = og.Generator(model, params)
        if args["verbose"]: print("Generator created")

        if args["verbose"]: print("Running generation loop ...")

        print()
        print("Output: ", end='', flush=True)

        generation_successful = True

        try:
            # Token-by-token generation loop; stream each decoded token to stdout
            while not generator.is_done():
                generator.compute_logits()
                generator.generate_next_token()

                new_token = generator.get_next_tokens()[0]

                value_to_save = tokenizer_stream.decode(new_token)

                complete_text += value_to_save

                print(value_to_save, end='', flush=True)

        except KeyboardInterrupt:
            print(" --control+c pressed, aborting generation--")
            generation_successful = False
        except Exception as e:
            print(f"An error occurred: {e}")
            generation_successful = False

        # Save the generated text for this prompt
        with open(f'{model_dir}/output.txt', 'a', encoding='utf-8') as file:
            file.write(complete_text)

        # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
        del generator

    return generation_successful

if __name__ == "__main__":
    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI question/answer example for ONNX Runtime GenAI")
    parser.add_argument('-j', '--json', type=str, required=True, help='Path to the JSON file containing the arguments')
    args = parser.parse_args()

    with open(args.json, 'r') as file:
        args = json.load(file)

    os.makedirs(args["output_directory"], exist_ok=True)
    os.makedirs(args["cache_directory"], exist_ok=True)

    output = []

    for model_dict in args["models"]:
        # Reset per-model state so one model's failure does not leak into the next row
        validation_complete = False
        exception = None

        print(f"We are validating {model_dict['name']}")
        adjusted_model = model_dict["name"].replace("/", "_")

        output_path = args["output_directory"] + f'/{adjusted_model}'
        cache_path = args["cache_directory"] + f'/{adjusted_model}'

        try:
            # Build the ONNX model and tokenizer for the requested precision and execution provider
            create_model(model_dict["name"], '', output_path, args["precision"], args["execution_provider"], cache_path)
        except Exception as e:
            print(f'Failure during model creation: {e}')
            output.append([model_dict["name"], validation_complete, e])
            continue

        try:
            validation_complete = validate_model(args, model_dict, output_path)
        except Exception as e:
            print(f'Failure during model validation: {e}')
            exception = e

        output.append([model_dict["name"], validation_complete, exception])

    df = create_table(output)

    df.to_csv("validation_summary.csv")