
Commit b905975

[ExecuTorch][#10447] Extend PyBundledModule with extension.BundledModule (#11595)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11565 by @Gasoonjia
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/13/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/13/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/14/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/13/orig
@diff-train-skip-merge

---------

Co-authored-by: gasoonjia <[email protected]>
1 parent 8b17e81 commit b905975

5 files changed: +96 −122 lines changed


CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -583,6 +583,14 @@ if(EXECUTORCH_BUILD_PYBIND)
     torch
   )
 
+  if(EXECUTORCH_BUILD_EXTENSION_MODULE)
+    if(CMAKE_TOOLCHAIN_IOS OR CMAKE_TOOLCHAIN_ANDROID OR APPLE)
+      list(APPEND _dep_libs extension_module_static)
+    else()
+      list(APPEND _dep_libs extension_module)
+    endif()
+  endif()
+
   if(EXECUTORCH_BUILD_TESTS)
     list(APPEND _dep_libs test_backend_compiler_lib)
   endif()

devtools/bundled_program/test/test_end2end.py

Lines changed: 1 addition & 30 deletions
@@ -5,21 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 # flake8: noqa: F401
-import functools
-import inspect
-import os
-import random
 import unittest
-from typing import Callable, Dict, Optional, Tuple, Type
-
-import executorch.exir as exir
-
-import executorch.exir.control_flow as control_flow
-
-# @manual=//executorch/extension/pytree:pybindings
-import executorch.extension.pytree as pytree
-
-import torch
 
 from executorch.devtools.bundled_program.core import BundledProgram
 from executorch.devtools.bundled_program.serialize import (
@@ -35,8 +21,6 @@
 try:
     from executorch.extension.pybindings.portable_lib import (
         _load_bundled_program_from_buffer,
-        _load_for_executorch_from_buffer,
-        _load_for_executorch_from_bundled_program,
     )
 
     kernel_mode = "lean"
@@ -47,8 +31,6 @@
 try:
     from executorch.extension.pybindings.aten_lib import (  # @manual=//executorch/extension/pybindings:aten_lib
         _load_bundled_program_from_buffer,
-        _load_for_executorch_from_buffer,
-        _load_for_executorch_from_bundled_program,
     )
 
     assert kernel_mode is None
@@ -75,19 +57,8 @@ def test_sample_model_e2e(self):
             bundled_program_buffer
         )
 
-        executorch_module = _load_for_executorch_from_bundled_program(
-            executorch_bundled_program
-        )
-
         for method_name in eager_model.method_names:
-            executorch_module.load_bundled_input(
-                executorch_bundled_program,
-                method_name,
-                0,
-            )
-            executorch_module.plan_execute(method_name)
-            executorch_module.verify_result_with_bundled_expected_output(
-                executorch_bundled_program,
+            executorch_bundled_program.verify_result_with_bundled_expected_output(
                 method_name,
                 0,
            )
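
With this change the end-to-end test drives verification through the bundled program object itself: the separately loaded module and the `load_bundled_input` / `plan_execute` pair are replaced by a single call. A minimal sketch of the new call pattern, assuming `bundled_program_buffer` and `method_names` come from the surrounding test setup and are not part of this hunk:

from executorch.extension.pybindings.portable_lib import (
    _load_bundled_program_from_buffer,
)

# Assumed context: `bundled_program_buffer` holds a serialized bundled program
# and `method_names` lists the methods under test (e.g. ["forward"]).
executorch_bundled_program = _load_bundled_program_from_buffer(
    bundled_program_buffer
)

for method_name in method_names:
    # Runs test set 0 of the method and compares its outputs against the
    # bundled expected outputs in one call; tolerances default to
    # rtol=1e-5 and atol=1e-8.
    executorch_bundled_program.verify_result_with_bundled_expected_output(
        method_name,
        0,
    )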

extension/pybindings/README.md

Lines changed: 1 addition & 2 deletions
@@ -27,8 +27,6 @@ CMAKE_ARGS="-DEXECUTORCH_BUILD_MPS=ON" ./install_executorch.sh
 - `_reset_profile_results()`: Reset profile results.
 ## Classes
 ### ExecuTorchModule
-- `load_bundled_input()`: Load bundled input.
-- `verify_result_with_bundled_expected_output(bundle: str, method_name: str, testset_idx: int, rtol: float = 1e-5, atol: float = 1e-8)`: Verify result with bundled expected output.
 - `plan_execute()`: Plan and execute.
 - `run_method()`: Run method.
 - `forward()`: Forward. This takes a pytree-flattend PyTorch-tensor-based input.
@@ -37,5 +35,6 @@ CMAKE_ARGS="-DEXECUTORCH_BUILD_MPS=ON" ./install_executorch.sh
 - `__call__()`: Call method.
 ### BundledModule
 This class is currently empty and serves as a placeholder for future methods and attributes.
+- `verify_result_with_bundled_expected_output(method_name: str, testset_idx: int, rtol: float = 1e-5, atol: float = 1e-8)`: Verify result with bundled expected output.
 ## Note
 All functions and methods are guarded by a call guard that redirects `cout` and `cerr` to the Python environment.

extension/pybindings/pybindings.cpp

Lines changed: 82 additions & 90 deletions
@@ -23,6 +23,7 @@
 #include <executorch/extension/data_loader/buffer_data_loader.h>
 #include <executorch/extension/data_loader/mmap_data_loader.h>
 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
+#include <executorch/extension/module/bundled_module.h>
 #include <executorch/extension/threadpool/threadpool.h>
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/data_loader.h>
@@ -96,6 +97,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Program;
 using ::executorch::extension::BufferDataLoader;
 using ::executorch::extension::MallocMemoryAllocator;
 using ::executorch::extension::MmapDataLoader;
+using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
 using ::executorch::runtime::ArrayRef;
 using ::executorch::runtime::DataLoader;
 using ::executorch::runtime::Error;
@@ -440,13 +442,54 @@ inline std::unique_ptr<Module> load_module_from_file(
       program_verification);
 }
 
+inline py::list get_outputs_as_py_list(
+    const std::vector<EValue>& outputs,
+    bool clone_outputs = true) {
+  const auto outputs_size = outputs.size();
+  py::list list(outputs_size);
+  for (size_t i = 0; i < outputs_size; ++i) {
+    auto& v = outputs[i];
+    if (Tag::None == v.tag) {
+      list[i] = py::none();
+    } else if (Tag::Int == v.tag) {
+      list[i] = py::cast(v.toInt());
+    } else if (Tag::Double == v.tag) {
+      list[i] = py::cast(v.toDouble());
+    } else if (Tag::Bool == v.tag) {
+      list[i] = py::cast(v.toBool());
+    } else if (Tag::String == v.tag) {
+      list[i] = py::cast(std::string(v.toString().data()));
+    } else if (Tag::Tensor == v.tag) {
+#ifdef USE_ATEN_LIB
+      // Clone so the outputs in python do not share a lifetime with the
+      // module object
+      if (clone_outputs) {
+        list[i] = py::cast(v.toTensor().clone());
+      } else {
+        list[i] = py::cast(v.toTensor());
+      }
+#else
+      if (clone_outputs) {
+        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
+      } else {
+        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
+      }
+#endif
+    } else {
+      ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
+    }
+  }
+  return list;
+}
+
 static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
 
-struct PyBundledModule final {
+struct PyBundledModule : public BundledModule {
   explicit PyBundledModule(
       const py::bytes& buffer,
       uint32_t bundled_input_pool_size)
-      : bundled_program_ptr_(buffer),
+      : BundledModule(buffer.cast<std::string_view>().data()),
+        bundled_program_ptr_(buffer),
         program_ptr_(static_cast<const void*>(
             bundled_program_flatbuffer::GetBundledProgram(
                 get_bundled_program_ptr())
@@ -475,6 +518,33 @@ struct PyBundledModule final {
     return program_len_;
   }
 
+  py::list verify_result_with_bundled_expected_output(
+      const std::string& method_name,
+      size_t testset_idx,
+      double rtol = 1e-5,
+      double atol = 1e-8) {
+    // Execute the method
+    auto result = BundledModule::execute(method_name, testset_idx);
+    if (!result.ok()) {
+      THROW_IF_ERROR(
+          result.error(),
+          "Method execution failed with status 0x%" PRIx32,
+          static_cast<uint32_t>(result.error()));
+    }
+
+    // Convert outputs to py::list
+    const auto& outputs = result.get();
+    py::list py_outputs = get_outputs_as_py_list(outputs);
+
+    Error status = BundledModule::verify_method_outputs(
+        method_name, testset_idx, rtol, atol);
+    THROW_IF_ERROR(
+        status,
+        "Result verification failed with status %" PRIu32,
+        static_cast<uint32_t>(status));
+    return py_outputs;
+  }
+
  private:
  // Store the bytes object instead of a raw pointer so that this module will
  // keep the bytes alive.
@@ -831,43 +901,6 @@ struct PyModule final {
     }
   }
 
-  void load_bundled_input(
-      PyBundledModule& m,
-      const std::string method_name,
-      size_t testset_idx) {
-    const void* bundled_program_ptr = m.get_bundled_program_ptr();
-    Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
-        module_->get_method(method_name), bundled_program_ptr, testset_idx);
-    THROW_IF_ERROR(
-        status,
-        "load_bundled_input failed with status 0x%" PRIx32,
-        static_cast<uint32_t>(status));
-  }
-
-  py::list verify_result_with_bundled_expected_output(
-      PyBundledModule& m,
-      const std::string method_name,
-      size_t testset_idx,
-      double rtol = 1e-5,
-      double atol = 1e-8) {
-    const void* bundled_program_ptr = m.get_bundled_program_ptr();
-    auto& method = module_->get_method(method_name);
-    Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
-        method, bundled_program_ptr, testset_idx);
-    THROW_IF_ERROR(
-        status,
-        "load_bundled_input failed with status 0x%" PRIx32,
-        static_cast<uint32_t>(status));
-    py::list outputs = plan_execute(method_name);
-    status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
-        method, bundled_program_ptr, testset_idx, rtol, atol);
-    THROW_IF_ERROR(
-        status,
-        "Result verification failed with status %" PRIu32,
-        static_cast<uint32_t>(status));
-    return outputs;
-  }
-
   py::list plan_execute(
       const std::string method_name,
       bool clone_outputs = true) {
@@ -890,46 +923,6 @@ struct PyModule final {
     return get_outputs_as_py_list(outputs, clone_outputs);
   }
 
-  py::list get_outputs_as_py_list(
-      const std::vector<EValue>& outputs,
-      bool clone_outputs = true) {
-    const auto outputs_size = outputs.size();
-    py::list list(outputs_size);
-    for (size_t i = 0; i < outputs_size; ++i) {
-      auto& v = outputs[i];
-      if (Tag::None == v.tag) {
-        list[i] = py::none();
-      } else if (Tag::Int == v.tag) {
-        list[i] = py::cast(v.toInt());
-      } else if (Tag::Double == v.tag) {
-        list[i] = py::cast(v.toDouble());
-      } else if (Tag::Bool == v.tag) {
-        list[i] = py::cast(v.toBool());
-      } else if (Tag::String == v.tag) {
-        list[i] = py::cast(std::string(v.toString().data()));
-      } else if (Tag::Tensor == v.tag) {
-#ifdef USE_ATEN_LIB
-        // Clone so the outputs in python do not share a lifetime with the
-        // module object
-        if (clone_outputs) {
-          list[i] = py::cast(v.toTensor().clone());
-        } else {
-          list[i] = py::cast(v.toTensor());
-        }
-#else
-        if (clone_outputs) {
-          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
-        } else {
-          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
-        }
-#endif
-      } else {
-        ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
-      }
-    }
-    return list;
-  }
-
   std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
     auto& method = module_->get_method(method_name);
     return std::make_unique<PyMethodMeta>(module_, method.method_meta());
@@ -1089,16 +1082,6 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
       call_guard);
 
   py::class_<PyModule>(m, "ExecuTorchModule")
-      .def("load_bundled_input", &PyModule::load_bundled_input, call_guard)
-      .def(
-          "verify_result_with_bundled_expected_output",
-          &PyModule::verify_result_with_bundled_expected_output,
-          py::arg("bundle"),
-          py::arg("method_name"),
-          py::arg("testset_idx"),
-          py::arg("rtol") = 1e-5,
-          py::arg("atol") = 1e-8,
-          call_guard)
       .def(
          "plan_execute",
          &PyModule::plan_execute,
@@ -1144,7 +1127,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
          py::arg("clone_outputs") = true,
          call_guard);
 
-  py::class_<PyBundledModule>(m, "BundledModule");
+  py::class_<PyBundledModule>(m, "BundledModule")
+      .def(
+          "verify_result_with_bundled_expected_output",
+          &PyBundledModule::verify_result_with_bundled_expected_output,
+          py::arg("method_name"),
+          py::arg("testset_idx"),
+          py::arg("rtol") = 1e-5,
+          py::arg("atol") = 1e-8,
+          call_guard);
+
   py::class_<PyTensorInfo>(m, "TensorInfo")
      .def("sizes", &PyTensorInfo::sizes, call_guard)
      .def("dtype", &PyTensorInfo::dtype, call_guard)

shim_et/xplat/executorch/extension/pybindings/pybindings.bzl

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,8 @@ PORTABLE_MODULE_DEPS = [
     "//executorch/extension/data_loader:buffer_data_loader",
     "//executorch/extension/data_loader:mmap_data_loader",
     "//executorch/extension/memory_allocator:malloc_memory_allocator",
+    "//executorch/extension/module:module",
+    "//executorch/extension/module:bundled_module",
     "//executorch/runtime/executor/test:test_backend_compiler_lib",
     "//executorch/devtools/etdump:etdump_flatcc",
 ] + get_all_cpu_backend_targets()
@@ -28,6 +30,8 @@ ATEN_MODULE_DEPS = [
     "//executorch/extension/data_loader:buffer_data_loader",
     "//executorch/extension/data_loader:mmap_data_loader",
     "//executorch/extension/memory_allocator:malloc_memory_allocator",
+    "//executorch/extension/module:module_aten",
+    "//executorch/extension/module:bundled_module_aten",
     "//executorch/devtools/bundled_program:runtime_aten",
     "//executorch/runtime/executor/test:test_backend_compiler_lib_aten",
     "//executorch/devtools/etdump:etdump_flatcc",