diff --git a/pywhispercpp/constants.py b/pywhispercpp/constants.py index 5809863..903d367 100644 --- a/pywhispercpp/constants.py +++ b/pywhispercpp/constants.py @@ -139,7 +139,7 @@ }, 'max_len': { 'type': int, - 'description': "max segment length in characters", + 'description': "max segment length in characters, note: token_timestamps needs to be set to True for this to work", 'options': None, 'default': 0 }, diff --git a/pywhispercpp/model.py b/pywhispercpp/model.py index 7c81ea9..9f4281e 100644 --- a/pywhispercpp/model.py +++ b/pywhispercpp/model.py @@ -70,6 +70,10 @@ def __init__(self, models_dir: str = None, params_sampling_strategy: int = 0, redirect_whispercpp_logs_to: Union[bool, TextIO, str, None] = False, + use_openvino: bool = False, + openvino_model_path: str = None, + openvino_device: str = 'CPU', + openvino_cache_dir: str = None, **params): """ :param model: The name of the model, one of the [AVAILABLE_MODELS](/pywhispercpp/#pywhispercpp.constants.AVAILABLE_MODELS), @@ -78,6 +82,10 @@ def __init__(self, exist, default to [MODELS_DIR](/pywhispercpp/#pywhispercpp.constants.MODELS_DIR) :param params_sampling_strategy: 0 -> GREEDY, else BEAM_SEARCH :param redirect_whispercpp_logs_to: where to redirect the whisper.cpp logs, default to False (no redirection), accepts str file path, sys.stdout, sys.stderr, or use None to redirect to devnull + :param use_openvino: whether to use OpenVINO or not + :param openvino_model_path: path to the OpenVINO model + :param openvino_device: OpenVINO device, default to CPU + :param openvino_cache_dir: OpenVINO cache directory :param params: keyword arguments for different whisper.cpp parameters, see [PARAMS_SCHEMA](/pywhispercpp/#pywhispercpp.constants.PARAMS_SCHEMA) """ @@ -90,8 +98,13 @@ def __init__(self, pw.whisper_sampling_strategy.WHISPER_SAMPLING_BEAM_SEARCH self._params = pw.whisper_full_default_params(self._sampling_strategy) # assign params + self.params = params self._set_params(params) self.redirect_whispercpp_logs_to = redirect_whispercpp_logs_to + self.use_openvino = use_openvino + self.openvino_model_path = openvino_model_path + self.openvino_device = openvino_device + self.openvino_cache_dir = openvino_cache_dir # init the model self._init_model() @@ -228,6 +241,10 @@ def _init_model(self) -> None: logger.info("Initializing the model ...") with utils.redirect_stderr(to=self.redirect_whispercpp_logs_to): self._ctx = pw.whisper_init_from_file(self.model_path) + if self.use_openvino: + pw.whisper_ctx_init_openvino_encoder(self._ctx, self.openvino_model_path, self.openvino_device, self.openvino_cache_dir) + + def _set_params(self, kwargs: dict) -> None: """ diff --git a/src/main.cpp b/src/main.cpp index 91b7670..e43490e 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -276,6 +276,12 @@ float whisper_full_get_token_p_wrapper(struct whisper_context_wrapper * ctx, int return whisper_full_get_token_p(ctx->ptr, i_segment, i_token); } +int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * ctx, const char * model_path, + const char * device, + const char * cache_dir){ + return whisper_ctx_init_openvino_encoder(ctx->ptr, model_path, device, cache_dir); +} + class WhisperFullParamsWrapper : public whisper_full_params { std::string initial_prompt_str; std::string suppress_regex_str; @@ -656,6 +662,9 @@ PYBIND11_MODULE(_pywhispercpp, m) { m.def("whisper_full_get_token_p", &whisper_full_get_token_p_wrapper, "Get the probability of the specified token in the specified segment."); + m.def("whisper_ctx_init_openvino_encoder", &whisper_ctx_init_openvino_encoder_wrapper, "Given a context, enable use of OpenVINO for encode inference."); + + //////////////////////////////////////////////////////////////////////////// m.def("whisper_bench_memcpy", &whisper_bench_memcpy, "Temporary helpers needed for exposing ggml interface");