23
23
#include < executorch/extension/data_loader/buffer_data_loader.h>
24
24
#include < executorch/extension/data_loader/mmap_data_loader.h>
25
25
#include < executorch/extension/memory_allocator/malloc_memory_allocator.h>
26
- #include < executorch/extension/module/bundled_module.h>
27
26
#include < executorch/extension/threadpool/threadpool.h>
28
27
#include < executorch/runtime/backend/interface.h>
29
28
#include < executorch/runtime/core/data_loader.h>
@@ -97,7 +96,6 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Program;
97
96
using ::executorch::extension::BufferDataLoader;
98
97
using ::executorch::extension::MallocMemoryAllocator;
99
98
using ::executorch::extension::MmapDataLoader;
100
- using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
101
99
using ::executorch::runtime::ArrayRef;
102
100
using ::executorch::runtime::DataLoader;
103
101
using ::executorch::runtime::Error;
@@ -442,54 +440,13 @@ inline std::unique_ptr<Module> load_module_from_file(
442
440
program_verification);
443
441
}
444
442
445
- inline py::list get_outputs_as_py_list (
446
- const std::vector<EValue>& outputs,
447
- bool clone_outputs = true ) {
448
- const auto outputs_size = outputs.size ();
449
- py::list list (outputs_size);
450
- for (size_t i = 0 ; i < outputs_size; ++i) {
451
- auto & v = outputs[i];
452
- if (Tag::None == v.tag ) {
453
- list[i] = py::none ();
454
- } else if (Tag::Int == v.tag ) {
455
- list[i] = py::cast (v.toInt ());
456
- } else if (Tag::Double == v.tag ) {
457
- list[i] = py::cast (v.toDouble ());
458
- } else if (Tag::Bool == v.tag ) {
459
- list[i] = py::cast (v.toBool ());
460
- } else if (Tag::String == v.tag ) {
461
- list[i] = py::cast (std::string (v.toString ().data ()));
462
- } else if (Tag::Tensor == v.tag ) {
463
- #ifdef USE_ATEN_LIB
464
- // Clone so the outputs in python do not share a lifetime with the
465
- // module object
466
- if (clone_outputs) {
467
- list[i] = py::cast (v.toTensor ().clone ());
468
- } else {
469
- list[i] = py::cast (v.toTensor ());
470
- }
471
- #else
472
- if (clone_outputs) {
473
- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
474
- } else {
475
- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
476
- }
477
- #endif
478
- } else {
479
- ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
480
- }
481
- }
482
- return list;
483
- }
484
-
485
443
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U ;
486
444
487
- struct PyBundledModule : public BundledModule {
445
+ struct PyBundledModule final {
488
446
explicit PyBundledModule (
489
447
const py::bytes& buffer,
490
448
uint32_t bundled_input_pool_size)
491
- : BundledModule(buffer.cast<std::string_view>().data()),
492
- bundled_program_ptr_(buffer),
449
+ : bundled_program_ptr_(buffer),
493
450
program_ptr_(static_cast <const void *>(
494
451
bundled_program_flatbuffer::GetBundledProgram (
495
452
get_bundled_program_ptr ())
@@ -518,33 +475,6 @@ struct PyBundledModule : public BundledModule {
518
475
return program_len_;
519
476
}
520
477
521
- py::list verify_result_with_bundled_expected_output (
522
- const std::string& method_name,
523
- size_t testset_idx,
524
- double rtol = 1e-5 ,
525
- double atol = 1e-8 ) {
526
- // Execute the method
527
- auto result = BundledModule::execute (method_name, testset_idx);
528
- if (!result.ok ()) {
529
- THROW_IF_ERROR (
530
- result.error (),
531
- " Method execution failed with status 0x%" PRIx32,
532
- static_cast <uint32_t >(result.error ()));
533
- }
534
-
535
- // Convert outputs to py::list
536
- const auto & outputs = result.get ();
537
- py::list py_outputs = get_outputs_as_py_list (outputs);
538
-
539
- Error status = BundledModule::verify_method_outputs (
540
- method_name, testset_idx, rtol, atol);
541
- THROW_IF_ERROR (
542
- status,
543
- " Result verification failed with status %" PRIu32,
544
- static_cast <uint32_t >(status));
545
- return py_outputs;
546
- }
547
-
548
478
private:
549
479
// Store the bytes object instead of a raw pointer so that this module will
550
480
// keep the bytes alive.
@@ -901,6 +831,43 @@ struct PyModule final {
901
831
}
902
832
}
903
833
834
+ void load_bundled_input (
835
+ PyBundledModule& m,
836
+ const std::string method_name,
837
+ size_t testset_idx) {
838
+ const void * bundled_program_ptr = m.get_bundled_program_ptr ();
839
+ Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
840
+ module_->get_method (method_name), bundled_program_ptr, testset_idx);
841
+ THROW_IF_ERROR (
842
+ status,
843
+ " load_bundled_input failed with status 0x%" PRIx32,
844
+ static_cast <uint32_t >(status));
845
+ }
846
+
847
+ py::list verify_result_with_bundled_expected_output (
848
+ PyBundledModule& m,
849
+ const std::string method_name,
850
+ size_t testset_idx,
851
+ double rtol = 1e-5 ,
852
+ double atol = 1e-8 ) {
853
+ const void * bundled_program_ptr = m.get_bundled_program_ptr ();
854
+ auto & method = module_->get_method (method_name);
855
+ Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
856
+ method, bundled_program_ptr, testset_idx);
857
+ THROW_IF_ERROR (
858
+ status,
859
+ " load_bundled_input failed with status 0x%" PRIx32,
860
+ static_cast <uint32_t >(status));
861
+ py::list outputs = plan_execute (method_name);
862
+ status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs (
863
+ method, bundled_program_ptr, testset_idx, rtol, atol);
864
+ THROW_IF_ERROR (
865
+ status,
866
+ " Result verification failed with status %" PRIu32,
867
+ static_cast <uint32_t >(status));
868
+ return outputs;
869
+ }
870
+
904
871
py::list plan_execute (
905
872
const std::string method_name,
906
873
bool clone_outputs = true ) {
@@ -923,6 +890,46 @@ struct PyModule final {
923
890
return get_outputs_as_py_list (outputs, clone_outputs);
924
891
}
925
892
893
+ py::list get_outputs_as_py_list (
894
+ const std::vector<EValue>& outputs,
895
+ bool clone_outputs = true ) {
896
+ const auto outputs_size = outputs.size ();
897
+ py::list list (outputs_size);
898
+ for (size_t i = 0 ; i < outputs_size; ++i) {
899
+ auto & v = outputs[i];
900
+ if (Tag::None == v.tag ) {
901
+ list[i] = py::none ();
902
+ } else if (Tag::Int == v.tag ) {
903
+ list[i] = py::cast (v.toInt ());
904
+ } else if (Tag::Double == v.tag ) {
905
+ list[i] = py::cast (v.toDouble ());
906
+ } else if (Tag::Bool == v.tag ) {
907
+ list[i] = py::cast (v.toBool ());
908
+ } else if (Tag::String == v.tag ) {
909
+ list[i] = py::cast (std::string (v.toString ().data ()));
910
+ } else if (Tag::Tensor == v.tag ) {
911
+ #ifdef USE_ATEN_LIB
912
+ // Clone so the outputs in python do not share a lifetime with the
913
+ // module object
914
+ if (clone_outputs) {
915
+ list[i] = py::cast (v.toTensor ().clone ());
916
+ } else {
917
+ list[i] = py::cast (v.toTensor ());
918
+ }
919
+ #else
920
+ if (clone_outputs) {
921
+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
922
+ } else {
923
+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
924
+ }
925
+ #endif
926
+ } else {
927
+ ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
928
+ }
929
+ }
930
+ return list;
931
+ }
932
+
926
933
std::unique_ptr<PyMethodMeta> method_meta (const std::string method_name) {
927
934
auto & method = module_->get_method (method_name);
928
935
return std::make_unique<PyMethodMeta>(module_, method.method_meta ());
@@ -1082,6 +1089,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
1082
1089
call_guard);
1083
1090
1084
1091
py::class_<PyModule>(m, " ExecuTorchModule" )
1092
+ .def (" load_bundled_input" , &PyModule::load_bundled_input, call_guard)
1093
+ .def (
1094
+ " verify_result_with_bundled_expected_output" ,
1095
+ &PyModule::verify_result_with_bundled_expected_output,
1096
+ py::arg (" bundle" ),
1097
+ py::arg (" method_name" ),
1098
+ py::arg (" testset_idx" ),
1099
+ py::arg (" rtol" ) = 1e-5 ,
1100
+ py::arg (" atol" ) = 1e-8 ,
1101
+ call_guard)
1085
1102
.def (
1086
1103
" plan_execute" ,
1087
1104
&PyModule::plan_execute,
@@ -1127,16 +1144,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
1127
1144
py::arg (" clone_outputs" ) = true ,
1128
1145
call_guard);
1129
1146
1130
- py::class_<PyBundledModule>(m, " BundledModule" )
1131
- .def (
1132
- " verify_result_with_bundled_expected_output" ,
1133
- &PyBundledModule::verify_result_with_bundled_expected_output,
1134
- py::arg (" method_name" ),
1135
- py::arg (" testset_idx" ),
1136
- py::arg (" rtol" ) = 1e-5 ,
1137
- py::arg (" atol" ) = 1e-8 ,
1138
- call_guard);
1139
-
1147
+ py::class_<PyBundledModule>(m, " BundledModule" );
1140
1148
py::class_<PyTensorInfo>(m, " TensorInfo" )
1141
1149
.def (" sizes" , &PyTensorInfo::sizes, call_guard)
1142
1150
.def (" dtype" , &PyTensorInfo::dtype, call_guard)
0 commit comments