
Commit 6b99e0d

Whisper Redesigned Solution (#1229)
### Description

This PR re-designs how Whisper is created and supported in ONNX Runtime GenAI. The new solution is designed to be used in conjunction with [this work](microsoft/onnxruntime#23549) in ONNX Runtime. Some of the added changes include:

- Re-designed GenAI config that separates the encoder model and decoder model
  - Removes the encoder-decoder-init section
  - Creates a new encoder section
  - Separates session options, EP options, and model properties to be per-model instead of re-using the decoder's options for all components
  - Re-assigns pre-computed cross-attention KV caches as outputs of the encoder model instead of inputs to the decoder model
- Re-designed runtime support that makes the states and steps much clearer
  - Creates `AudioEncoder`, `WhisperDecoder` (i.e. `TextDecoder`), and `WhisperState` as separate states
  - Creates an `AudioFeatures` class that can be re-used for other speech models
  - Adds generic support for FP32 CPU, FP32 CUDA, FP16 CUDA, and any quantized versions
  - Removes temporary workarounds for past-present buffer sharing that were needed due to restrictions from both the exported ONNX model and ONNX Runtime
  - Handles models with and without the following: buffer sharing, `DecoderMaskedMultiHeadAttention`, and alignment heads
- Moves setting inputs for non-LLMs from `GeneratorParams.SetInputs` to `Generator.SetInputs` (breaking change)
  - Allows any non-LLM to use just the `SetInputs` API for setting model inputs now. Previously, some models such as Whisper needed a combination of `SetInputs` and `AppendTokenSequences` to set all model inputs.
  - This does change our published examples for Phi vision (Phi-3 vision and Phi-3.5 vision) and Phi multimodal (Phi-4 multimodal) since they will need to use `Generator.SetInputs` instead of `GeneratorParams.SetInputs`. A before/after sketch follows the Known Issues list below.
- New APIs for getting and setting inputs and outputs as well as for processing batched inputs
  - `OgaGenerator_SetModelInput`
  - `OgaGenerator_SetInputs`
  - `OgaGenerator_GetInput`
  - `OgaProcessorProcessImagesAndPrompts`
  - `OgaProcessorProcessAudiosAndPrompts`
  - `OgaProcessorProcessImagesAndAudiosAndPrompts`
- Unit tests for audio pre-processing and end-to-end inference

### Known Issues

- The cross QK kernels do not have parity with the alternative, more accurate approach of computing the cross QKs as a separate inference pass. Currently, it is recommended to use the alternative approach for calculating word-level timestamps.
- The cross QK kernels are only supported for CUDA.
- The end-to-end working example from OpenAI's implementation is still under development [here](openai/whisper@main...kunal-vaishnavi:whisper:kvaishnavi/onnxruntime-genai). Once working, a copy of those scripts will be added as a sub-folder in the Python examples.
- Missing features inside ONNX Runtime GenAI include:
  - Adding `patience` to the `search` options
  - Rewinding the conversation for Whisper
  - Calculating logprobs
  - Logit filtering (e.g. suppressing blank tokens, suppressing special tokens, applying timestamp rules, etc.)
  - Getting/setting logits on device memory (currently, there are CPU-GPU copies that are incurred)
  - Decoding with temperature fallback
  - Running the encoder and decoder-init steps separately (currently both are merged into one because ONNX Runtime GenAI requires the `Run` API to return logits; some refactoring needs to be done to run the encoder step separately)
- Note that the existing Whisper examples inside the `examples` folder can still be used. Word-level timestamps are not currently calculated in those examples.
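To make the breaking change concrete, here is a minimal before/after sketch in the Python bindings, taken from the example updates in this commit (`model` and `inputs` are assumed to be created as in the examples below, i.e. `inputs` is the `NamedTensors` object returned by the processor):

```python
import onnxruntime_genai as og

# Before this commit: inputs were attached to the GeneratorParams
params = og.GeneratorParams(model)
params.set_search_options(max_length=7680)
params.set_inputs(inputs)  # removed by this commit

# After this commit: inputs are attached to the Generator itself
params = og.GeneratorParams(model)
params.set_search_options(max_length=7680)
generator = og.Generator(model, params)
generator.set_inputs(inputs)
```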
### Motivation and Context

The original implementation of Whisper was added in ONNX Runtime GenAI to create an initial foundation. This new approach is more flexible and more customizable for users. It also introduces an encoder-decoder architecture setup that can be used for other encoder-decoder models or other speech models.

---------

Co-authored-by: mindest <[email protected]>
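To make the config redesign described above concrete, below is a hedged sketch of the new per-model layout, written as a Python dict so it can carry annotations. The key names are illustrative assumptions drawn from the description, not a verbatim copy of the `genai_config.json` schema:

```python
# Illustrative sketch only: key names are assumptions based on the PR
# description, not the exact genai_config.json schema.
genai_config_sketch = {
    "model": {
        "encoder": {  # new section; the old encoder-decoder-init section is gone
            "filename": "encoder_model.onnx",  # hypothetical filename
            "session_options": {},  # per-model session/EP options, not shared
            # Pre-computed cross-attention KV caches are now listed among the
            # encoder's *outputs* rather than the decoder's inputs.
        },
        "decoder": {
            "filename": "decoder_model.onnx",  # hypothetical filename
            "session_options": {},  # decoder options no longer re-used for all components
        },
    },
}
```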
1 parent 88bb45c commit 6b99e0d

85 files changed: +132,381 −1,028 lines

.github/workflows/linux-gpu-x64-build.yml

Lines changed: 2 additions & 2 deletions

```diff
@@ -143,7 +143,7 @@ jobs:
       docker run \
         --gpus all \
         --rm \
-        --volume /data/ortgenai/pytorch:/data/ortgenai/pytorch \
+        --volume /data/ortgenai/:/data/ortgenai/ \
         --volume $GITHUB_WORKSPACE:/ort_genai_src \
         -e HF_TOKEN=$HF_TOKEN \
         -w /ort_genai_src onnxruntimecudabuildx64 bash -c " \
@@ -170,6 +170,6 @@ jobs:
       docker run \
         --gpus all \
         --rm \
-        --volume /data/ortgenai/pytorch:/data/ortgenai/pytorch \
+        --volume /data/ortgenai/:/data/ortgenai/ \
         --volume $GITHUB_WORKSPACE:/ort_genai_src \
         -w /ort_genai_src onnxruntimecudabuildx64 bash -c "ORTGENAI_LOG_ORT_LIB=1 LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/ort_genai_src/build/cuda/ /ort_genai_src/build/cuda/unit_tests"
```

cmake/deps.txt

Lines changed: 1 addition & 1 deletion

```diff
@@ -14,7 +14,7 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029
 googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
-onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;5ea4b9b0683b83c1d6800eb332f37dcc76bb2e61
+onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;a85fa861ee5e5300f16142bd969ede0eabc61c86
 
 # These two dependencies are for the optional constrained decoding feature (USE_GUIDANCE)
 llguidance;https://github.com/microsoft/llguidance.git;2d2f1de3c87e3289528affc346f734f7471216d9
```

examples/c/src/phi3v.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -63,9 +63,9 @@ void CXX_API(const char* model_path, const char* execution_provider) {
   std::cout << "Generating response..." << std::endl;
   auto params = OgaGeneratorParams::Create(*model);
   params->SetSearchOption("max_length", 7680);
-  params->SetInputs(*input_tensors);
 
   auto generator = OgaGenerator::Create(*model, *params);
+  generator->SetInputs(*input_tensors);
 
   while (!generator->IsDone()) {
     generator->GenerateNextToken();
```

examples/c/src/phi4-mm.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -94,9 +94,9 @@ void CXX_API(const char* model_path, const char* execution_provider) {
   std::cout << "Generating response..." << std::endl;
   auto params = OgaGeneratorParams::Create(*model);
   params->SetSearchOption("max_length", 7680);
-  params->SetInputs(*input_tensors);
 
   auto generator = OgaGenerator::Create(*model, *params);
+  generator->SetInputs(*input_tensors);
 
   while (!generator->IsDone()) {
     generator->GenerateNextToken();
```

examples/c/src/whisper.cpp

Lines changed: 23 additions & 35 deletions

```diff
@@ -15,8 +15,6 @@ void CXX_API(const char* model_path, int32_t num_beams) {
   auto model = OgaModel::Create(model_path);
   std::cout << "Creating multimodal processor..." << std::endl;
   auto processor = OgaMultiModalProcessor::Create(*model);
-  std::cout << "Creating tokenizer..." << std::endl;
-  auto tokenizer = OgaTokenizer::Create(*model);
 
   while (true) {
     std::string audio_paths_str;
@@ -42,31 +40,24 @@ void CXX_API(const char* model_path, int32_t num_beams) {
      audios = OgaAudios::Load(audio_paths_c);
    }
 
-    std::cout << "Processing audio..." << std::endl;
-    auto mel = processor->ProcessAudios(audios.get());
-    const std::vector<const char*> prompt_tokens = {"<|startoftranscript|>", "<|en|>", "<|transcribe|>",
-                                                    "<|notimestamps|>"};
-    auto input_ids = OgaSequences::Create();
+    std::cout << "Processing inputs..." << std::endl;
    const size_t batch_size = audio_paths.size();
-    for (size_t i = 0; i < batch_size; ++i) {
-      for (const auto& token : prompt_tokens) {
-        input_ids->Append(tokenizer->ToTokenId(token), i);
-      }
-    }
+    const char* prompt_tokens = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>";
+    const std::vector<const char*> prompts(batch_size, prompt_tokens);
+    auto inputs = processor->ProcessAudios(prompts, audios.get());
 
    std::cout << "Generating response..." << std::endl;
    auto params = OgaGeneratorParams::Create(*model);
-    params->SetSearchOption("max_length", 256);
+    params->SetSearchOption("batch_size", static_cast<double>(batch_size));
+    params->SetSearchOption("max_length", 448);
    params->SetSearchOptionBool("do_sample", false);
    params->SetSearchOption("num_beams", num_beams);
    params->SetSearchOption("num_return_sequences", num_beams);
-    params->SetInputs(*mel);
-    params->SetInputSequences(*input_ids);
 
    auto generator = OgaGenerator::Create(*model, *params);
+    generator->SetInputs(*inputs);
 
    while (!generator->IsDone()) {
-      generator->ComputeLogits();
      generator->GenerateNextToken();
    }
 
@@ -133,36 +124,29 @@ void C_API(const char* model_path, int32_t num_beams) {
    }
 
    std::cout << "Processing audio..." << std::endl;
-    OgaNamedTensors* mel;
-    CheckResult(OgaProcessorProcessAudios(processor, audios, &mel));
-    const std::vector<const char*> prompt_tokens = {"<|startoftranscript|>", "<|en|>", "<|transcribe|>",
-                                                    "<|notimestamps|>"};
-    OgaSequences* input_ids;
-    CheckResult(OgaCreateSequences(&input_ids));
+    OgaNamedTensors* inputs;
    const size_t batch_size = audio_paths.size();
-    for (size_t i = 0; i < batch_size; ++i) {
-      for (const auto& token : prompt_tokens) {
-        int32_t token_id;
-        CheckResult(OgaTokenizerToTokenId(tokenizer, token, &token_id));
-        CheckResult(OgaAppendTokenToSequence(token_id, input_ids, i));
-      }
-    }
+    const char* prompt_tokens = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>";
+    std::vector<const char*> prompts(batch_size, prompt_tokens);
+    OgaStringArray* prompts_string_array;
+    CheckResult(OgaCreateStringArrayFromStrings(prompts.data(), prompts.size(), &prompts_string_array));
+    CheckResult(OgaProcessorProcessAudiosAndPrompts(processor, prompts_string_array, audios, &inputs));
+    OgaDestroyStringArray(prompts_string_array);
 
    std::cout << "Generating response..." << std::endl;
    OgaGeneratorParams* params;
    CheckResult(OgaCreateGeneratorParams(model, &params));
-    CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 256));
+    CheckResult(OgaGeneratorParamsSetSearchNumber(params, "batch_size", static_cast<double>(batch_size)));
+    CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 448));
    CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", false));
    CheckResult(OgaGeneratorParamsSetSearchNumber(params, "num_beams", num_beams));
    CheckResult(OgaGeneratorParamsSetSearchNumber(params, "num_return_sequences", num_beams));
-    CheckResult(OgaGeneratorParamsSetInputs(params, mel));
-    CheckResult(OgaGeneratorParamsSetInputSequences(params, input_ids));
 
    OgaGenerator* generator;
    CheckResult(OgaCreateGenerator(model, params, &generator));
+    CheckResult(OgaGenerator_SetInputs(generator, inputs));
 
    while (!OgaGenerator_IsDone(generator)) {
-      CheckResult(OgaGenerator_ComputeLogits(generator));
      CheckResult(OgaGenerator_GenerateNextToken(generator));
    }
 
@@ -182,8 +166,7 @@ void C_API(const char* model_path, int32_t num_beams) {
 
    OgaDestroyGenerator(generator);
    OgaDestroyGeneratorParams(params);
-    OgaDestroySequences(input_ids);
-    OgaDestroyNamedTensors(mel);
+    OgaDestroyNamedTensors(inputs);
    OgaDestroyAudios(audios);
  }
 
@@ -203,6 +186,11 @@ int main(int argc, char** argv) {
    return -1;
  }
 
+  // Uncomment for debugging purposes
+  // Oga::SetLogBool("enabled", true);
+  // Oga::SetLogBool("model_input_values", true);
+  // Oga::SetLogBool("model_output_values", true);
+
  std::cout << "---------------" << std::endl;
  std::cout << "Hello, Whisper!" << std::endl;
  std::cout << "---------------" << std::endl;
```

examples/csharp/HelloPhi3V/Program.cs

Lines changed: 1 addition & 1 deletion

```diff
@@ -163,9 +163,9 @@ void PrintUsage()
     Console.WriteLine("Generating response...");
     using GeneratorParams generatorParams = new GeneratorParams(model);
     generatorParams.SetSearchOption("max_length", 7680);
-    generatorParams.SetInputs(inputTensors);
 
     using var generator = new Generator(model, generatorParams);
+    generator.SetInputs(inputTensors);
     var watch = System.Diagnostics.Stopwatch.StartNew();
     while (!generator.IsDone())
     {
```

examples/csharp/HelloPhi4MM/Program.cs

Lines changed: 1 addition & 1 deletion

```diff
@@ -204,9 +204,9 @@ void PrintUsage()
     Console.WriteLine("Generating response...");
     using GeneratorParams generatorParams = new GeneratorParams(model);
     generatorParams.SetSearchOption("max_length", 7680);
-    generatorParams.SetInputs(inputTensors);
 
     using var generator = new Generator(model, generatorParams);
+    generator.SetInputs(inputTensors);
     var watch = System.Diagnostics.Stopwatch.StartNew();
     while (!generator.IsDone())
     {
```

examples/python/model-vision.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -8,6 +8,7 @@
 from pathlib import Path
 
 import onnxruntime_genai as og
+# og.set_log_options(enabled=True, model_input_values=True, model_output_values=True)
 
 def _find_dir_contains_sub_dir(current_dir: Path, target_dir_name):
     curr_path = Path(current_dir).absolute()
@@ -103,10 +104,10 @@ def run(args: argparse.Namespace):
 
     print("Generating response...")
     params = og.GeneratorParams(model)
-    params.set_inputs(inputs)
     params.set_search_options(max_length=7680)
 
     generator = og.Generator(model, params)
+    generator.set_inputs(inputs)
     start_time = time.time()
 
     while not generator.is_done():
```

examples/python/phi4-mm.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -124,10 +124,10 @@ def run(args: argparse.Namespace):
 
     print("Generating response...")
     params = og.GeneratorParams(model)
-    params.set_inputs(inputs)
     params.set_search_options(max_length=7680)
 
     generator = og.Generator(model, params)
+    generator.set_inputs(inputs)
     start_time = time.time()
 
     while not generator.is_done():
```

examples/python/whisper.py

Lines changed: 46 additions & 10 deletions

```diff
@@ -7,7 +7,7 @@
 import readline
 
 import onnxruntime_genai as og
-
+# og.set_log_options(enabled=True, model_input_values=True, model_output_values=True)
 
 def _complete(text, state):
     return (glob.glob(text + "*") + [None])[state]
@@ -20,15 +20,25 @@ class Format:
 
 def run(args: argparse.Namespace):
     print("Loading model...")
-    model = og.Model(args.model_path)
+    config = og.Config(args.model_path)
+    if args.execution_provider != "follow_config":
+        config.clear_providers()
+        if args.execution_provider != "cpu":
+            print(f"Setting model to {args.execution_provider}")
+            config.append_provider(args.execution_provider)
+    model = og.Model(config)
     processor = model.create_multimodal_processor()
     tokenizer = og.Tokenizer(model)
 
     while True:
         readline.set_completer_delims(" \t\n;")
         readline.parse_and_bind("tab: complete")
         readline.set_completer(_complete)
-        audio_paths = [audio_path.strip() for audio_path in input("Audio Paths (comma separated): ").split(",")]
+
+        if args.non_interactive:
+            audio_paths = [args.audio]
+        else:
+            audio_paths = [audio_path.strip() for audio_path in input("Audio Paths (comma separated): ").split(",")]
         if len(audio_paths) == 0:
             raise ValueError("No audio provided.")
 
@@ -39,28 +49,27 @@ def run(args: argparse.Namespace):
         audios = og.Audios.open(*audio_paths)
 
         print("Processing audio...")
-        mel = processor(audios=audios)
+        batch_size = len(audio_paths)
         decoder_prompt_tokens = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"]
+        prompts = ["".join(decoder_prompt_tokens)] * batch_size
+        inputs = processor(prompts, audios=audios)
 
         params = og.GeneratorParams(model)
         params.set_search_options(
             do_sample=False,
             num_beams=args.num_beams,
             num_return_sequences=args.num_beams,
-            max_length=256,
+            max_length=448,
         )
 
-        batch_size = len(audio_paths)
-        params.set_inputs(mel)
-        params.input_ids = [[tokenizer.to_token_id(token) for token in decoder_prompt_tokens]] * batch_size
-
         generator = og.Generator(model, params)
+        generator.set_inputs(inputs)
 
         while not generator.is_done():
-            generator.compute_logits()
             generator.generate_next_token()
 
         print()
+        transcriptions = []
         for i in range(batch_size * args.num_beams):
             tokens = generator.get_sequence(i)
             transcription = processor.decode(tokens)
@@ -69,18 +78,45 @@ def run(args: argparse.Namespace):
             print(
                 f" {Format.underline}batch {i // args.num_beams}, beam {i % args.num_beams}{Format.end}: {transcription}"
             )
+            transcriptions.append(transcription.strip())
 
         for _ in range(3):
             print()
 
+        if args.non_interactive:
+            args.output = args.output.strip()
+            matching = False
+            for transcription in transcriptions:
+                if transcription == args.output:
+                    matching = True
+                    break
+
+            if matching:
+                print("One of the model's transcriptions matches the expected transcription.")
+                return
+            raise Exception("None of the model's transcriptions match the expected transcription.")
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-m", "--model_path", type=str, required=True, help="Path to the model"
     )
+    parser.add_argument(
+        '-e', '--execution_provider', type=str, required=False, default='follow_config', choices=["cpu", "cuda", "follow_config"],
+        help="Execution provider to run the ONNX Runtime session with. Defaults to follow_config that uses the execution provider listed in the genai_config.json instead."
+    )
     parser.add_argument(
         "-b", "--num_beams", type=int, default=4, help="Number of beams"
     )
+    parser.add_argument(
+        "-a", "--audio", type=str, default="", help="Path to audio file for CI testing purposes"
+    )
+    parser.add_argument(
+        "-o", "--output", type=str, default="", help="Expected transcribed output for CI testing purposes"
+    )
+    parser.add_argument(
+        "-ni", "--non_interactive", default=False, action="store_true", help="Non-interactive mode for CI testing purposes"
+    )
     args = parser.parse_args()
     run(args)
```
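Pulling the hunks above together, here is a minimal non-interactive sketch of the new Whisper flow in Python, distilled from this example; the model directory and audio path are placeholders:

```python
import onnxruntime_genai as og

model = og.Model("path/to/whisper-model")  # placeholder model directory
processor = model.create_multimodal_processor()

audios = og.Audios.open("audio.wav")  # placeholder audio file
# One prompt string per audio in the batch
prompts = ["<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"]
inputs = processor(prompts, audios=audios)

params = og.GeneratorParams(model)
params.set_search_options(do_sample=False, num_beams=4, num_return_sequences=4, max_length=448)

generator = og.Generator(model, params)
generator.set_inputs(inputs)  # inputs are now set on the generator, not the params
while not generator.is_done():
    generator.generate_next_token()

# Decode the top beam of the first (and only) batch entry
print(processor.decode(generator.get_sequence(0)))
```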
