 #include <executorch/extension/data_loader/buffer_data_loader.h>
 #include <executorch/extension/data_loader/mmap_data_loader.h>
 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
+#include <executorch/extension/module/bundled_module.h>
 #include <executorch/extension/threadpool/threadpool.h>
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/data_loader.h>
@@ -96,6 +97,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Program;
 using ::executorch::extension::BufferDataLoader;
 using ::executorch::extension::MallocMemoryAllocator;
 using ::executorch::extension::MmapDataLoader;
+using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
 using ::executorch::runtime::ArrayRef;
 using ::executorch::runtime::DataLoader;
 using ::executorch::runtime::Error;
@@ -440,13 +442,54 @@ inline std::unique_ptr<Module> load_module_from_file(
       program_verification);
 }
 
+inline py::list get_outputs_as_py_list(
+    const std::vector<EValue>& outputs,
+    bool clone_outputs = true) {
+  const auto outputs_size = outputs.size();
+  py::list list(outputs_size);
+  for (size_t i = 0; i < outputs_size; ++i) {
+    auto& v = outputs[i];
+    if (Tag::None == v.tag) {
+      list[i] = py::none();
+    } else if (Tag::Int == v.tag) {
+      list[i] = py::cast(v.toInt());
+    } else if (Tag::Double == v.tag) {
+      list[i] = py::cast(v.toDouble());
+    } else if (Tag::Bool == v.tag) {
+      list[i] = py::cast(v.toBool());
+    } else if (Tag::String == v.tag) {
+      list[i] = py::cast(std::string(v.toString().data()));
+    } else if (Tag::Tensor == v.tag) {
+#ifdef USE_ATEN_LIB
+      // Clone so the outputs in python do not share a lifetime with the
+      // module object
+      if (clone_outputs) {
+        list[i] = py::cast(v.toTensor().clone());
+      } else {
+        list[i] = py::cast(v.toTensor());
+      }
+#else
+      if (clone_outputs) {
+        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
+      } else {
+        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
+      }
+#endif
+    } else {
+      ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
+    }
+  }
+  return list;
+}
+
 static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
 
-struct PyBundledModule final {
+struct PyBundledModule : public BundledModule {
   explicit PyBundledModule(
       const py::bytes& buffer,
       uint32_t bundled_input_pool_size)
-      : bundled_program_ptr_(buffer),
+      : BundledModule(buffer.cast<std::string_view>().data()),
+        bundled_program_ptr_(buffer),
         program_ptr_(static_cast<const void*>(
             bundled_program_flatbuffer::GetBundledProgram(
                 get_bundled_program_ptr())
@@ -475,6 +518,33 @@ struct PyBundledModule final {
     return program_len_;
   }
 
+  py::list verify_result_with_bundled_expected_output(
+      const std::string& method_name,
+      size_t testset_idx,
+      double rtol = 1e-5,
+      double atol = 1e-8) {
+    // Execute the method
+    auto result = BundledModule::execute(method_name, testset_idx);
+    if (!result.ok()) {
+      THROW_IF_ERROR(
+          result.error(),
+          "Method execution failed with status 0x%" PRIx32,
+          static_cast<uint32_t>(result.error()));
+    }
+
+    // Convert outputs to py::list
+    const auto& outputs = result.get();
+    py::list py_outputs = get_outputs_as_py_list(outputs);
+
+    Error status = BundledModule::verify_method_outputs(
+        method_name, testset_idx, rtol, atol);
+    THROW_IF_ERROR(
+        status,
+        "Result verification failed with status %" PRIu32,
+        static_cast<uint32_t>(status));
+    return py_outputs;
+  }
+
  private:
   // Store the bytes object instead of a raw pointer so that this module will
   // keep the bytes alive.
@@ -831,43 +901,6 @@ struct PyModule final {
     }
   }
 
-  void load_bundled_input(
-      PyBundledModule& m,
-      const std::string method_name,
-      size_t testset_idx) {
-    const void* bundled_program_ptr = m.get_bundled_program_ptr();
-    Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
-        module_->get_method(method_name), bundled_program_ptr, testset_idx);
-    THROW_IF_ERROR(
-        status,
-        "load_bundled_input failed with status 0x%" PRIx32,
-        static_cast<uint32_t>(status));
-  }
-
-  py::list verify_result_with_bundled_expected_output(
-      PyBundledModule& m,
-      const std::string method_name,
-      size_t testset_idx,
-      double rtol = 1e-5,
-      double atol = 1e-8) {
-    const void* bundled_program_ptr = m.get_bundled_program_ptr();
-    auto& method = module_->get_method(method_name);
-    Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
-        method, bundled_program_ptr, testset_idx);
-    THROW_IF_ERROR(
-        status,
-        "load_bundled_input failed with status 0x%" PRIx32,
-        static_cast<uint32_t>(status));
-    py::list outputs = plan_execute(method_name);
-    status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
-        method, bundled_program_ptr, testset_idx, rtol, atol);
-    THROW_IF_ERROR(
-        status,
-        "Result verification failed with status %" PRIu32,
-        static_cast<uint32_t>(status));
-    return outputs;
-  }
-
   py::list plan_execute(
       const std::string method_name,
       bool clone_outputs = true) {
py::list plan_execute (
872
905
const std::string method_name,
873
906
bool clone_outputs = true ) {
@@ -890,46 +923,6 @@ struct PyModule final {
     return get_outputs_as_py_list(outputs, clone_outputs);
   }
 
-  py::list get_outputs_as_py_list(
-      const std::vector<EValue>& outputs,
-      bool clone_outputs = true) {
-    const auto outputs_size = outputs.size();
-    py::list list(outputs_size);
-    for (size_t i = 0; i < outputs_size; ++i) {
-      auto& v = outputs[i];
-      if (Tag::None == v.tag) {
-        list[i] = py::none();
-      } else if (Tag::Int == v.tag) {
-        list[i] = py::cast(v.toInt());
-      } else if (Tag::Double == v.tag) {
-        list[i] = py::cast(v.toDouble());
-      } else if (Tag::Bool == v.tag) {
-        list[i] = py::cast(v.toBool());
-      } else if (Tag::String == v.tag) {
-        list[i] = py::cast(std::string(v.toString().data()));
-      } else if (Tag::Tensor == v.tag) {
-#ifdef USE_ATEN_LIB
-        // Clone so the outputs in python do not share a lifetime with the
-        // module object
-        if (clone_outputs) {
-          list[i] = py::cast(v.toTensor().clone());
-        } else {
-          list[i] = py::cast(v.toTensor());
-        }
-#else
-        if (clone_outputs) {
-          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
-        } else {
-          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
-        }
-#endif
-      } else {
-        ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
-      }
-    }
-    return list;
-  }
-
   std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
     auto& method = module_->get_method(method_name);
     return std::make_unique<PyMethodMeta>(module_, method.method_meta());
std::unique_ptr<PyMethodMeta> method_meta (const std::string method_name) {
934
927
auto & method = module_->get_method (method_name);
935
928
return std::make_unique<PyMethodMeta>(module_, method.method_meta ());
@@ -1089,16 +1082,6 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
       call_guard);
 
   py::class_<PyModule>(m, "ExecuTorchModule")
-      .def("load_bundled_input", &PyModule::load_bundled_input, call_guard)
-      .def(
-          "verify_result_with_bundled_expected_output",
-          &PyModule::verify_result_with_bundled_expected_output,
-          py::arg("bundle"),
-          py::arg("method_name"),
-          py::arg("testset_idx"),
-          py::arg("rtol") = 1e-5,
-          py::arg("atol") = 1e-8,
-          call_guard)
       .def(
           "plan_execute",
           &PyModule::plan_execute,
@@ -1144,7 +1127,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
           py::arg("clone_outputs") = true,
           call_guard);
 
-  py::class_<PyBundledModule>(m, "BundledModule");
+  py::class_<PyBundledModule>(m, "BundledModule")
+      .def(
+          "verify_result_with_bundled_expected_output",
+          &PyBundledModule::verify_result_with_bundled_expected_output,
+          py::arg("method_name"),
+          py::arg("testset_idx"),
+          py::arg("rtol") = 1e-5,
+          py::arg("atol") = 1e-8,
+          call_guard);
+
   py::class_<PyTensorInfo>(m, "TensorInfo")
       .def("sizes", &PyTensorInfo::sizes, call_guard)
       .def("dtype", &PyTensorInfo::dtype, call_guard)