Skip to content

Commit 636a95e

Browse files
authored
Hold onto named tensors to ensure they don't get garbage collected in Python (#1174)
1 parent ee1fadd commit 636a95e

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

src/python/python.cpp

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,13 @@ struct PyDeviceMemorySpan {
215215
pybind11::array_t<T> py_cpu_array_;
216216
};
217217

218+
struct PyNamedTensors {
219+
PyNamedTensors(std::unique_ptr<NamedTensors> named_tensors) : named_tensors_{std::move(named_tensors)} {
220+
}
221+
222+
std::unique_ptr<NamedTensors> named_tensors_;
223+
};
224+
218225
struct PyGeneratorParams {
219226
PyGeneratorParams(const Model& model) : params_{std::make_shared<GeneratorParams>(model)} {
220227
}
@@ -238,6 +245,11 @@ struct PyGeneratorParams {
238245
refs_.emplace_back(value);
239246
}
240247

248+
void SetInputs(std::shared_ptr<PyNamedTensors> named_tensors) {
249+
params_->SetInputs(*named_tensors->named_tensors_);
250+
named_tensors_ = named_tensors;
251+
}
252+
241253
void SetSearchOptions(const pybind11::kwargs& dict) {
242254
for (auto& entry : dict) {
243255
auto name = entry.first.cast<std::string>();
@@ -268,14 +280,8 @@ struct PyGeneratorParams {
268280
pybind11::array py_whisper_input_features_;
269281
pybind11::array py_alignment_heads_;
270282

271-
std::vector<pybind11::object> refs_; // References to data we want to ensure doesn't get garbage collected
272-
};
273-
274-
struct PyNamedTensors {
275-
PyNamedTensors(std::unique_ptr<NamedTensors> named_tensors) : named_tensors_{std::move(named_tensors)} {
276-
}
277-
278-
std::unique_ptr<NamedTensors> named_tensors_;
283+
std::vector<pybind11::object> refs_; // References to data we want to ensure doesn't get garbage collected
284+
std::shared_ptr<PyNamedTensors> named_tensors_; // Ensure the model inputs don't get garbage collected
279285
};
280286

281287
struct PyGenerator {
@@ -387,11 +393,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
387393
// TODO(baijumeswani): Rename/redesign the whisper_input_features to be more generic
388394
.def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_)
389395
.def_readwrite("alignment_heads", &PyGeneratorParams::py_alignment_heads_)
390-
.def("set_inputs", [](PyGeneratorParams& generator_params, PyNamedTensors* named_tensors) {
396+
.def("set_inputs", [](PyGeneratorParams& generator_params, std::shared_ptr<PyNamedTensors> named_tensors) {
391397
if (!named_tensors || !named_tensors->named_tensors_)
392398
throw std::runtime_error("No inputs provided.");
393399

394-
generator_params.params_->SetInputs(*named_tensors->named_tensors_);
400+
generator_params.SetInputs(named_tensors);
395401
})
396402
.def("set_model_input", &PyGeneratorParams::SetModelInput)
397403
.def("set_search_options", &PyGeneratorParams::SetSearchOptions) // See config.h 'struct Search' for the options
@@ -456,8 +462,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
456462
generator.SetActiveAdapter(adapters, adapter_name);
457463
});
458464

459-
pybind11::class_<Images>(m, "Images")
460-
.def_static("open", [](pybind11::args image_paths) {
465+
pybind11::class_<Images, std::shared_ptr<Images>>(m, "Images")
466+
.def_static("open", [](pybind11::args image_paths) -> std::shared_ptr<Images> {
461467
if (image_paths.empty())
462468
throw std::runtime_error("No images provided");
463469

@@ -470,7 +476,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
470476
image_paths_vector.push_back(image_paths_string.back().c_str());
471477
}
472478

473-
return LoadImages(image_paths_vector);
479+
return std::shared_ptr<Images>(LoadImages(image_paths_vector));
474480
})
475481
.def_static("open_bytes", [](pybind11::args image_datas) {
476482
if (image_datas.empty())
@@ -486,10 +492,10 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
486492
image_raw_data[i] = ort_extensions::ImageRawData(data, data + info.size);
487493
}
488494

489-
return std::make_unique<Images>(std::move(image_raw_data), image_datas.size());
495+
return std::make_shared<Images>(std::move(image_raw_data), image_datas.size());
490496
});
491497

492-
pybind11::class_<Audios>(m, "Audios")
498+
pybind11::class_<Audios, std::shared_ptr<Audios>>(m, "Audios")
493499
.def_static("open", [](pybind11::args audio_paths) {
494500
if (audio_paths.empty())
495501
throw std::runtime_error("No audios provided");
@@ -504,14 +510,14 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
504510
audio_paths_vector.push_back(audio_paths_string.back().c_str());
505511
}
506512

507-
return LoadAudios(audio_paths_vector);
513+
return std::shared_ptr<Audios>(LoadAudios(audio_paths_vector));
508514
});
509515

510-
pybind11::class_<PyNamedTensors>(m, "NamedTensors");
516+
pybind11::class_<PyNamedTensors, std::shared_ptr<PyNamedTensors>>(m, "NamedTensors");
511517

512518
pybind11::class_<MultiModalProcessor, std::shared_ptr<MultiModalProcessor>>(m, "MultiModalProcessor")
513519
.def(
514-
"__call__", [](MultiModalProcessor& processor, const std::optional<std::string>& prompt, const pybind11::kwargs& kwargs) -> std::unique_ptr<PyNamedTensors> {
520+
"__call__", [](MultiModalProcessor& processor, const std::optional<std::string>& prompt, const pybind11::kwargs& kwargs) -> std::shared_ptr<PyNamedTensors> {
515521
if (kwargs.contains("images")) {
516522
if (processor.image_processor_ == nullptr) {
517523
throw std::runtime_error("Image processor is not available for this model.");
@@ -520,11 +526,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
520526
if (!prompt.has_value()) {
521527
throw std::runtime_error("Prompt is required for processing the image.");
522528
}
523-
return std::make_unique<PyNamedTensors>(
529+
return std::make_shared<PyNamedTensors>(
524530
processor.image_processor_->Process(*processor.tokenizer_, *prompt, images));
525531
} else if (kwargs.contains("audios")) {
526532
const Audios* audios = kwargs["audios"].cast<const Audios*>();
527-
return std::make_unique<PyNamedTensors>(
533+
return std::make_shared<PyNamedTensors>(
528534
processor.audio_processor_->Process(audios));
529535
} else {
530536
throw std::runtime_error("Nothing to process.");

0 commit comments

Comments
 (0)