@@ -215,6 +215,13 @@ struct PyDeviceMemorySpan {
215215 pybind11::array_t <T> py_cpu_array_;
216216};
217217
218+ struct PyNamedTensors {
219+ PyNamedTensors (std::unique_ptr<NamedTensors> named_tensors) : named_tensors_{std::move (named_tensors)} {
220+ }
221+
222+ std::unique_ptr<NamedTensors> named_tensors_;
223+ };
224+
218225struct PyGeneratorParams {
219226 PyGeneratorParams (const Model& model) : params_{std::make_shared<GeneratorParams>(model)} {
220227 }
@@ -238,6 +245,11 @@ struct PyGeneratorParams {
238245 refs_.emplace_back (value);
239246 }
240247
248+ void SetInputs (std::shared_ptr<PyNamedTensors> named_tensors) {
249+ params_->SetInputs (*named_tensors->named_tensors_ );
250+ named_tensors_ = named_tensors;
251+ }
252+
241253 void SetSearchOptions (const pybind11::kwargs& dict) {
242254 for (auto & entry : dict) {
243255 auto name = entry.first .cast <std::string>();
@@ -268,14 +280,8 @@ struct PyGeneratorParams {
268280 pybind11::array py_whisper_input_features_;
269281 pybind11::array py_alignment_heads_;
270282
271- std::vector<pybind11::object> refs_; // References to data we want to ensure doesn't get garbage collected
272- };
273-
274- struct PyNamedTensors {
275- PyNamedTensors (std::unique_ptr<NamedTensors> named_tensors) : named_tensors_{std::move (named_tensors)} {
276- }
277-
278- std::unique_ptr<NamedTensors> named_tensors_;
283+ std::vector<pybind11::object> refs_; // References to data we want to ensure doesn't get garbage collected
284+ std::shared_ptr<PyNamedTensors> named_tensors_; // Ensure the model inputs don't get garbage collected
279285};
280286
281287struct PyGenerator {
@@ -387,11 +393,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
387393 // TODO(baijumeswani): Rename/redesign the whisper_input_features to be more generic
388394 .def_readwrite (" whisper_input_features" , &PyGeneratorParams::py_whisper_input_features_)
389395 .def_readwrite (" alignment_heads" , &PyGeneratorParams::py_alignment_heads_)
390- .def (" set_inputs" , [](PyGeneratorParams& generator_params, PyNamedTensors* named_tensors) {
396+ .def (" set_inputs" , [](PyGeneratorParams& generator_params, std::shared_ptr< PyNamedTensors> named_tensors) {
391397 if (!named_tensors || !named_tensors->named_tensors_ )
392398 throw std::runtime_error (" No inputs provided." );
393399
394- generator_params.params_ -> SetInputs (* named_tensors-> named_tensors_ );
400+ generator_params.SetInputs (named_tensors);
395401 })
396402 .def (" set_model_input" , &PyGeneratorParams::SetModelInput)
397403 .def (" set_search_options" , &PyGeneratorParams::SetSearchOptions) // See config.h 'struct Search' for the options
@@ -456,8 +462,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
456462 generator.SetActiveAdapter (adapters, adapter_name);
457463 });
458464
459- pybind11::class_<Images>(m, " Images" )
460- .def_static (" open" , [](pybind11::args image_paths) {
465+ pybind11::class_<Images, std::shared_ptr<Images> >(m, " Images" )
466+ .def_static (" open" , [](pybind11::args image_paths) -> std::shared_ptr<Images> {
461467 if (image_paths.empty ())
462468 throw std::runtime_error (" No images provided" );
463469
@@ -470,7 +476,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
470476 image_paths_vector.push_back (image_paths_string.back ().c_str ());
471477 }
472478
473- return LoadImages (image_paths_vector);
479+ return std::shared_ptr<Images>( LoadImages (image_paths_vector) );
474480 })
475481 .def_static (" open_bytes" , [](pybind11::args image_datas) {
476482 if (image_datas.empty ())
@@ -486,10 +492,10 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
486492 image_raw_data[i] = ort_extensions::ImageRawData (data, data + info.size );
487493 }
488494
489- return std::make_unique <Images>(std::move (image_raw_data), image_datas.size ());
495+ return std::make_shared <Images>(std::move (image_raw_data), image_datas.size ());
490496 });
491497
492- pybind11::class_<Audios>(m, " Audios" )
498+ pybind11::class_<Audios, std::shared_ptr<Audios> >(m, " Audios" )
493499 .def_static (" open" , [](pybind11::args audio_paths) {
494500 if (audio_paths.empty ())
495501 throw std::runtime_error (" No audios provided" );
@@ -504,14 +510,14 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
504510 audio_paths_vector.push_back (audio_paths_string.back ().c_str ());
505511 }
506512
507- return LoadAudios (audio_paths_vector);
513+ return std::shared_ptr<Audios>( LoadAudios (audio_paths_vector) );
508514 });
509515
510- pybind11::class_<PyNamedTensors>(m, " NamedTensors" );
516+ pybind11::class_<PyNamedTensors, std::shared_ptr<PyNamedTensors> >(m, " NamedTensors" );
511517
512518 pybind11::class_<MultiModalProcessor, std::shared_ptr<MultiModalProcessor>>(m, " MultiModalProcessor" )
513519 .def (
514- " __call__" , [](MultiModalProcessor& processor, const std::optional<std::string>& prompt, const pybind11::kwargs& kwargs) -> std::unique_ptr <PyNamedTensors> {
520+ " __call__" , [](MultiModalProcessor& processor, const std::optional<std::string>& prompt, const pybind11::kwargs& kwargs) -> std::shared_ptr <PyNamedTensors> {
515521 if (kwargs.contains (" images" )) {
516522 if (processor.image_processor_ == nullptr ) {
517523 throw std::runtime_error (" Image processor is not available for this model." );
@@ -520,11 +526,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
520526 if (!prompt.has_value ()) {
521527 throw std::runtime_error (" Prompt is required for processing the image." );
522528 }
523- return std::make_unique <PyNamedTensors>(
529+ return std::make_shared <PyNamedTensors>(
524530 processor.image_processor_ ->Process (*processor.tokenizer_ , *prompt, images));
525531 } else if (kwargs.contains (" audios" )) {
526532 const Audios* audios = kwargs[" audios" ].cast <const Audios*>();
527- return std::make_unique <PyNamedTensors>(
533+ return std::make_shared <PyNamedTensors>(
528534 processor.audio_processor_ ->Process (audios));
529535 } else {
530536 throw std::runtime_error (" Nothing to process." );
0 commit comments