Skip to content

Commit 93e9fcd

Browse files
pytorchbotlucylq
andauthored
Add MergedDataMap to method (#12305)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12088 by @lucylq ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/88/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/88/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/87/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/88/orig @diff-train-skip-merge --------- Co-authored-by: lucylq <[email protected]>
1 parent e13b8e4 commit 93e9fcd

File tree

4 files changed

+51
-25
lines changed

4 files changed

+51
-25
lines changed

runtime/executor/method.cpp

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <executorch/runtime/core/named_data_map.h>
2121
#include <executorch/runtime/core/span.h>
2222
#include <executorch/runtime/executor/memory_manager.h>
23+
#include <executorch/runtime/executor/merged_data_map.h>
2324
#include <executorch/runtime/executor/platform_memory_allocator.h>
2425
#include <executorch/runtime/executor/program.h>
2526
#include <executorch/runtime/executor/tensor_parser.h>
@@ -328,9 +329,9 @@ Result<size_t> Method::get_num_external_constants() {
328329
return n_external_constants;
329330
}
330331

331-
Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
332+
Error Method::parse_external_constants(const NamedDataMap* external_data_map) {
332333
ET_CHECK_OR_RETURN_ERROR(
333-
named_data_map != nullptr, InvalidState, "named_data_map is null");
334+
external_data_map != nullptr, InvalidState, "external_data_map is null");
334335
auto flatbuffer_values = serialization_plan_->values();
335336
size_t n_value = flatbuffer_values->size();
336337

@@ -372,7 +373,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
372373
continue;
373374
}
374375
Result<const TensorLayout> tensor_layout =
375-
named_data_map->get_tensor_layout(key);
376+
external_data_map->get_tensor_layout(key);
376377
if (!tensor_layout.ok()) {
377378
ET_LOG(Info, "Failed to get metadata for key %s", key);
378379
return tensor_layout.error();
@@ -387,7 +388,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
387388
external_constants_[n_external_constants_].key = key;
388389

389390
// Save the buffer.
390-
Result<FreeableBuffer> buffer = named_data_map->get_data(key);
391+
Result<FreeableBuffer> buffer = external_data_map->get_data(key);
391392
ET_CHECK_OR_RETURN_ERROR(
392393
buffer.ok(),
393394
InvalidExternalData,
@@ -400,7 +401,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
400401
return Error::Ok;
401402
}
402403

403-
Error Method::parse_values(const NamedDataMap* named_data_map) {
404+
Error Method::parse_values(const NamedDataMap* external_data_map) {
404405
auto flatbuffer_values = serialization_plan_->values();
405406
ET_CHECK_OR_RETURN_ERROR(
406407
flatbuffer_values != nullptr, InvalidProgram, "Missing values");
@@ -428,7 +429,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) {
428429
if (external_constants_ == nullptr) {
429430
return Error::MemoryAllocationFailed;
430431
}
431-
Error err = parse_external_constants(named_data_map);
432+
Error err = parse_external_constants(external_data_map);
432433
if (err != Error::Ok) {
433434
return err;
434435
}
@@ -541,7 +542,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) {
541542
program_,
542543
memory_manager_,
543544
static_cast<const executorch_flatbuffer::Tensor*>(val),
544-
named_data_map,
545+
external_data_map,
545546
Span<NamedData>(external_constants_, n_external_constants_));
546547
if (!t.ok()) {
547548
ET_LOG(
@@ -741,7 +742,7 @@ Result<Method> Method::load(
741742
const Program* program,
742743
MemoryManager* memory_manager,
743744
EventTracer* event_tracer,
744-
const NamedDataMap* named_data_map) {
745+
const NamedDataMap* external_data_map) {
745746
MemoryAllocator* temp_allocator = memory_manager->temp_allocator();
746747
if (temp_allocator == nullptr) {
747748
PlatformMemoryAllocator* platform_allocator =
@@ -755,7 +756,7 @@ Result<Method> Method::load(
755756
}
756757
Method method(program, memory_manager, event_tracer, temp_allocator);
757758
ET_LOG(Debug, "Loading method: %s.", s_plan->name()->c_str());
758-
Error err = method.init(s_plan, named_data_map);
759+
Error err = method.init(s_plan, external_data_map);
759760
if (err != Error::Ok) {
760761
return err;
761762
} else {
@@ -766,7 +767,7 @@ Result<Method> Method::load(
766767

767768
Error Method::init(
768769
executorch_flatbuffer::ExecutionPlan* s_plan,
769-
const NamedDataMap* named_data_map) {
770+
const NamedDataMap* external_data_map) {
770771
EXECUTORCH_SCOPE_PROF("Method::init");
771772
internal::EventTracerProfileMethodScope event_tracer_profile_scope =
772773
internal::EventTracerProfileMethodScope(event_tracer_, "Method::init");
@@ -783,7 +784,7 @@ Error Method::init(
783784

784785
{
785786
// Parse the elements of the values_ array.
786-
Error err = parse_values(named_data_map);
787+
Error err = parse_values(external_data_map);
787788
if (err != Error::Ok) {
788789
return err;
789790
}
@@ -800,21 +801,34 @@ Error Method::init(
800801
return Error::MemoryAllocationFailed;
801802
}
802803

803-
// Get NamedDataMap, if it exists.
804-
const NamedDataMap* pte_data_map = nullptr;
805-
Result<const NamedDataMap*> pte_data_map_res =
806-
program_->get_named_data_map();
807-
if (pte_data_map_res.ok()) {
808-
pte_data_map = pte_data_map_res.get();
809-
}
810-
804+
// Get PTE data map, if it exists.
805+
auto pte_data_map = program_->get_named_data_map();
811806
ET_CHECK_OR_RETURN_ERROR(
812-
!(pte_data_map && named_data_map),
813-
NotSupported,
814-
"NamedDataMap merge not supported; both pte_data_map and named_data_map are non-empty. If you see this error please file an issue at https://github.com/pytorch/executorch/issues");
815-
816-
if (!named_data_map || named_data_map->get_num_keys().get() == 0) {
817-
named_data_map = pte_data_map;
807+
pte_data_map.ok() || pte_data_map.error() == Error::NotFound,
808+
InvalidProgram,
809+
"Failed to get named data map from program: 0x%" PRIx32,
810+
static_cast<uint32_t>(pte_data_map.error()));
811+
812+
const NamedDataMap* named_data_map = nullptr;
813+
if (external_data_map && pte_data_map.ok()) {
814+
// Merge external_data_map and pte_data_map if both are present.
815+
auto merged =
816+
internal::MergedDataMap::load(external_data_map, pte_data_map.get());
817+
if (!merged.ok()) {
818+
return merged.error();
819+
}
820+
// Allocate memory for the merged data map.
821+
merged_data_map_ =
822+
method_allocator->allocateInstance<internal::MergedDataMap>();
823+
if (merged_data_map_ == nullptr) {
824+
return Error::MemoryAllocationFailed;
825+
}
826+
new (merged_data_map_) internal::MergedDataMap(std::move(merged.get()));
827+
named_data_map = merged_data_map_;
828+
} else if (external_data_map) {
829+
named_data_map = external_data_map;
830+
} else if (pte_data_map.ok()) {
831+
named_data_map = pte_data_map.get();
818832
}
819833

820834
// n_delegate_ counts the number of successfully-initialized delegates for
@@ -1680,6 +1694,10 @@ Method::~Method() {
16801694
for (const auto i : c10::irange(n_external_constants_)) {
16811695
external_constants_[i].buffer.~FreeableBuffer();
16821696
}
1697+
// Free the MergedDataMap.
1698+
if (merged_data_map_ != nullptr) {
1699+
merged_data_map_->~MergedDataMap();
1700+
}
16831701
// All other fields are trivially destructible.
16841702
}
16851703
} // namespace ET_RUNTIME_NAMESPACE

runtime/executor/method.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <executorch/runtime/core/named_data_map.h>
2121
#include <executorch/runtime/core/span.h>
2222
#include <executorch/runtime/executor/memory_manager.h>
23+
#include <executorch/runtime/executor/merged_data_map.h>
2324
#include <executorch/runtime/executor/method_meta.h>
2425
#include <executorch/runtime/platform/compiler.h>
2526

@@ -76,6 +77,7 @@ class Method final {
7677
delegates_(rhs.delegates_),
7778
n_chains_(rhs.n_chains_),
7879
chains_(rhs.chains_),
80+
merged_data_map_(std::move(rhs.merged_data_map_)),
7981
external_constants_(rhs.external_constants_),
8082
n_external_constants_(rhs.n_external_constants_),
8183
init_state_(rhs.init_state_) {
@@ -85,6 +87,8 @@ class Method final {
8587
rhs.values_ = nullptr;
8688
rhs.n_delegate_ = 0;
8789
rhs.delegates_ = nullptr;
90+
91+
rhs.merged_data_map_ = nullptr;
8892
rhs.n_external_constants_ = 0;
8993
rhs.external_constants_ = nullptr;
9094

@@ -314,6 +318,7 @@ class Method final {
314318
delegates_(nullptr),
315319
n_chains_(0),
316320
chains_(nullptr),
321+
merged_data_map_(nullptr),
317322
external_constants_(nullptr),
318323
n_external_constants_(0),
319324
init_state_(InitializationState::Uninitialized) {}
@@ -364,6 +369,7 @@ class Method final {
364369
size_t n_chains_;
365370
Chain* chains_;
366371

372+
internal::MergedDataMap* merged_data_map_;
367373
NamedData* external_constants_;
368374
size_t n_external_constants_ = 0;
369375

runtime/executor/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def define_common_targets():
117117
exported_deps = [
118118
":memory_manager",
119119
":pte_data_map" + aten_suffix,
120+
":merged_data_map" + aten_suffix,
120121
"//executorch/runtime/backend:interface" + aten_suffix,
121122
"//executorch/runtime/core:core",
122123
"//executorch/runtime/core:named_data_map" + aten_suffix,

runtime/executor/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def define_common_targets(is_fbcode = False):
163163
],
164164
deps = [
165165
":managed_memory_manager",
166+
"//executorch/runtime/executor:merged_data_map",
166167
"//executorch/runtime/executor:program",
167168
"//executorch/extension/data_loader:file_data_loader",
168169
"//executorch/extension/flat_tensor:flat_tensor_data_map",

0 commit comments

Comments
 (0)