|
1 | 1 | /**
|
2 | 2 | ********************************************************************************
|
3 | 3 | * @file main.cpp
|
4 |
| - * @author [abdeladim-s](https://github.com/abdeladim-s) |
| 4 | + * @author [absadiki](https://github.com/absadiki) |
5 | 5 | * @date 2023
|
6 | 6 | * @brief Python bindings for [whisper.cpp](https://github.com/ggerganov/whisper.cpp) using Pybind11
|
7 | 7 | *
|
@@ -280,12 +280,26 @@ class WhisperFullParamsWrapper : public whisper_full_params {
|
280 | 280 | std::string initial_prompt_str;
|
281 | 281 | std::string suppress_regex_str;
|
282 | 282 | public:
|
| 283 | + py::function py_progress_callback; |
283 | 284 | WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params())
|
284 | 285 | : whisper_full_params(params),
|
285 | 286 | initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""),
|
286 | 287 | suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") {
|
287 | 288 | initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
|
288 | 289 | suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
|
| 290 | + // progress callback |
| 291 | + progress_callback_user_data = this; |
| 292 | + progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) { |
| 293 | + auto* self = static_cast<WhisperFullParamsWrapper*>(user_data); |
| 294 | + if(self && self->print_progress){ |
| 295 | + if (self->py_progress_callback) { |
| 296 | + self->py_progress_callback(progress); // Call Python callback |
| 297 | + } |
| 298 | + else { |
| 299 | + fprintf(stderr, "Progress: %3d%%\n", progress); |
| 300 | + } // Default message |
| 301 | + } |
| 302 | + } ; |
289 | 303 | }
|
290 | 304 |
|
291 | 305 | WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other)
|
@@ -557,6 +571,7 @@ PYBIND11_MODULE(_pywhispercpp, m) {
|
557 | 571 | .def_readwrite("single_segment", &WhisperFullParamsWrapper::single_segment)
|
558 | 572 | .def_readwrite("print_special", &WhisperFullParamsWrapper::print_special)
|
559 | 573 | .def_readwrite("print_progress", &WhisperFullParamsWrapper::print_progress)
|
| 574 | + .def_readwrite("progress_callback", &WhisperFullParamsWrapper::py_progress_callback) |
560 | 575 | .def_readwrite("print_realtime", &WhisperFullParamsWrapper::print_realtime)
|
561 | 576 | .def_readwrite("print_timestamps", &WhisperFullParamsWrapper::print_timestamps)
|
562 | 577 | .def_readwrite("token_timestamps", &WhisperFullParamsWrapper::token_timestamps)
|
|
0 commit comments