### Description
This PR re-designs how Whisper is created and supported in ONNX Runtime
GenAI. The new solution is designed to be used in conjunction with [this
work](microsoft/onnxruntime#23549) in ONNX
Runtime.
Some of the added changes include:
- Re-designed GenAI config that separates the encoder model and the decoder model (an illustrative sketch of the new layout follows this list)
  - Removes the encoder-decoder-init section
  - Creates a new encoder section
  - Separates session options, EP options, and model properties so they are set per model instead of reusing the decoder's options for all components
  - Re-assigns pre-computed cross-attention KV caches as outputs of the encoder model instead of inputs to the decoder model
- Re-designed runtime support that makes the states and steps much clearer
  - Creates `AudioEncoder`, `WhisperDecoder` (i.e. `TextDecoder`), and `WhisperState` as separate states
  - Creates an `AudioFeatures` class that can be re-used for other speech models
  - Adds generic support for FP32 CPU, FP32 CUDA, FP16 CUDA, and any quantized versions
  - Removes temporary workarounds for past-present buffer sharing that were needed due to restrictions in both the exported ONNX model and ONNX Runtime
  - Handles models with and without the following: buffer sharing, `DecoderMaskedMultiHeadAttention`, and alignment heads
- Moves setting inputs for non-LLMs from `GeneratorParams.SetInputs` to `Generator.SetInputs` (breaking change; see the Python sketch after this list)
  - Allows any non-LLM to set all of its model inputs with the `SetInputs` API alone. Previously, some models such as Whisper needed a combination of `SetInputs` and `AppendTokenSequences` to set all model inputs.
  - This changes our published examples for Phi vision (Phi-3 vision and Phi-3.5 vision) and Phi multimodal (Phi-4 multimodal), since they will need to use `Generator.SetInputs` instead of `GeneratorParams.SetInputs`.
- New APIs for getting and setting inputs and outputs as well as processing batched inputs
  - `OgaGenerator_SetModelInput`
  - `OgaGenerator_SetInputs`
  - `OgaGenerator_GetInput`
  - `OgaProcessorProcessImagesAndPrompts`
  - `OgaProcessorProcessAudiosAndPrompts`
  - `OgaProcessorProcessImagesAndAudiosAndPrompts`
- Unit tests for audio pre-processing and end-to-end inference
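For illustration, here is a hedged sketch of how the reorganized config could separate the encoder and decoder. Everything below (file names, key names such as `cross_present_key_names`, and the search values) is an assumption patterned on the description above, not the authoritative `genai_config.json` schema produced by the model builder.

```python
import json

# Illustrative sketch only: the field names below are assumptions inferred from
# the description above, not the authoritative genai_config.json schema.
whisper_config = {
    "model": {
        "type": "whisper",
        "encoder": {
            # The new encoder section replaces the old encoder-decoder-init section
            "filename": "whisper_encoder.onnx",
            # Session options, EP options, and model properties are now per model
            "session_options": {"provider_options": [{"cuda": {}}]},
            "inputs": {"audio_features": "audio_features"},
            # Pre-computed cross-attention KV caches are now encoder outputs
            # rather than decoder inputs
            "outputs": {
                "cross_present_key_names": "present_key_cross_%d",
                "cross_present_value_names": "present_value_cross_%d",
            },
        },
        "decoder": {
            "filename": "whisper_decoder.onnx",
            "session_options": {"provider_options": [{"cuda": {}}]},
            "inputs": {"input_ids": "input_ids"},
            "outputs": {"logits": "logits"},
        },
    },
    "search": {"num_beams": 4, "max_length": 448},
}

print(json.dumps(whisper_config, indent=2))
```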
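To make the `Generator.SetInputs` breaking change concrete, here is a hedged Python sketch of the new flow for Whisper. The audio-loading helper (`open_audios`), the processor call signature, and the prompt string are assumptions patterned on the published multimodal examples; the point is only that inputs now go to the generator rather than to `GeneratorParams`.

```python
import onnxruntime_genai as og

# Hedged sketch: helper names and the prompt are assumptions, not the exact
# published example. Only the Generator.SetInputs move is the point here.
model = og.Model("path/to/whisper-onnx-model")
processor = model.create_multimodal_processor()
tokenizer = og.Tokenizer(model)

audios = og.Audios.open_audios("sample.wav")  # assumed audio-loading helper
prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
inputs = processor(prompt, audios=audios)

params = og.GeneratorParams(model)
params.set_search_options(num_beams=4, max_length=448)

# Old flow (before this PR): params.set_inputs(inputs), then create the generator.
# New flow (this PR): create the generator first, then set the inputs on it.
generator = og.Generator(model, params)
generator.set_inputs(inputs)

while not generator.is_done():
    generator.generate_next_token()

print(tokenizer.decode(generator.get_sequence(0)))
```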
### Known Issues
- The cross QK kernels do not have parity with the alternative, more accurate approach of computing the cross QKs in a separate inference pass. Currently, the alternative approach is recommended for calculating word-level timestamps.
- The cross QK kernels are only supported for CUDA.
- The end-to-end working example from OpenAI's implementation is still
under development
[here](openai/whisper@main...kunal-vaishnavi:whisper:kvaishnavi/onnxruntime-genai).
Once working, a copy of those scripts will be added as a sub-folder in
the Python examples.
- Missing features inside ONNX Runtime GenAI include:
  - Adding `patience` to the `search`
  - Rewinding the conversation for Whisper
  - Calculating logprobs
  - Logit filtering (e.g. suppressing blank tokens, suppressing special tokens, applying timestamp rules, etc.)
  - Getting/setting logits in device memory (currently, CPU-GPU copies are incurred)
  - Decoding with temperature fallback
  - Running the encoder and decoder-init steps separately (currently both are merged into one because ONNX Runtime GenAI requires the `Run` API to return logits; some refactoring is needed to run the encoder step separately)
- Note that the existing Whisper examples inside the `examples` folder
can still be used. Word-level timestamps are not currently calculated in
those examples.
### Motivation and Context
The original Whisper implementation in ONNX Runtime GenAI was added to create an initial foundation. This new approach is more flexible and more customizable for users. It also introduces an encoder-decoder architecture setup that can be reused for other encoder-decoder models and other speech models.
---------
Co-authored-by: mindest <[email protected]>
help="Execution provider to run the ONNX Runtime session with. Defaults to follow_config that uses the execution provider listed in the genai_config.json instead."
108
+
)
82
109
parser.add_argument(
83
110
"-b", "--num_beams", type=int, default=4, help="Number of beams"
84
111
)
112
+
parser.add_argument(
113
+
"-a", "--audio", type=str, default="", help="Path to audio file for CI testing purposes"
114
+
)
115
+
parser.add_argument(
116
+
"-o", "--output", type=str, default="", help="Expected transcribed output for CI testing purposes"
117
+
)
118
+
parser.add_argument(
119
+
"-ni", "--non_interactive", default=False, action="store_true", help="Non-interactive mode for CI testing purposes"