Skip to content

Commit 7030ec9

Browse files
committed
fix(print_progress): print progress callback #97
1 parent d6f73bb commit 7030ec9

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

src/main.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/**
22
********************************************************************************
33
* @file main.cpp
4-
* @author [abdeladim-s](https://github.com/abdeladim-s)
4+
* @author [absadiki](https://github.com/absadiki)
55
* @date 2023
66
* @brief Python bindings for [whisper.cpp](https://github.com/ggerganov/whisper.cpp) using Pybind11
77
*
@@ -280,12 +280,26 @@ class WhisperFullParamsWrapper : public whisper_full_params {
280280
std::string initial_prompt_str;
281281
std::string suppress_regex_str;
282282
public:
283+
py::function py_progress_callback;
283284
WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params())
284285
: whisper_full_params(params),
285286
initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""),
286287
suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") {
287288
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
288289
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
290+
// progress callback
291+
progress_callback_user_data = this;
292+
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
293+
auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
294+
if(self && self->print_progress){
295+
if (self->py_progress_callback) {
296+
self->py_progress_callback(progress); // Call Python callback
297+
}
298+
else {
299+
fprintf(stderr, "Progress: %3d%%\n", progress);
300+
} // Default message
301+
}
302+
} ;
289303
}
290304

291305
WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other)
@@ -557,6 +571,7 @@ PYBIND11_MODULE(_pywhispercpp, m) {
557571
.def_readwrite("single_segment", &WhisperFullParamsWrapper::single_segment)
558572
.def_readwrite("print_special", &WhisperFullParamsWrapper::print_special)
559573
.def_readwrite("print_progress", &WhisperFullParamsWrapper::print_progress)
574+
.def_readwrite("progress_callback", &WhisperFullParamsWrapper::py_progress_callback)
560575
.def_readwrite("print_realtime", &WhisperFullParamsWrapper::print_realtime)
561576
.def_readwrite("print_timestamps", &WhisperFullParamsWrapper::print_timestamps)
562577
.def_readwrite("token_timestamps", &WhisperFullParamsWrapper::token_timestamps)

0 commit comments

Comments
 (0)