20
20
#include < executorch/runtime/core/named_data_map.h>
21
21
#include < executorch/runtime/core/span.h>
22
22
#include < executorch/runtime/executor/memory_manager.h>
23
+ #include < executorch/runtime/executor/merged_data_map.h>
23
24
#include < executorch/runtime/executor/platform_memory_allocator.h>
24
25
#include < executorch/runtime/executor/program.h>
25
26
#include < executorch/runtime/executor/tensor_parser.h>
@@ -328,9 +329,9 @@ Result<size_t> Method::get_num_external_constants() {
328
329
return n_external_constants;
329
330
}
330
331
331
- Error Method::parse_external_constants (const NamedDataMap* named_data_map ) {
332
+ Error Method::parse_external_constants (const NamedDataMap* external_data_map ) {
332
333
ET_CHECK_OR_RETURN_ERROR (
333
- named_data_map != nullptr , InvalidState, " named_data_map is null" );
334
+ external_data_map != nullptr , InvalidState, " external_data_map is null" );
334
335
auto flatbuffer_values = serialization_plan_->values ();
335
336
size_t n_value = flatbuffer_values->size ();
336
337
@@ -372,7 +373,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
372
373
continue ;
373
374
}
374
375
Result<const TensorLayout> tensor_layout =
375
- named_data_map ->get_tensor_layout (key);
376
+ external_data_map ->get_tensor_layout (key);
376
377
if (!tensor_layout.ok ()) {
377
378
ET_LOG (Info, " Failed to get metadata for key %s" , key);
378
379
return tensor_layout.error ();
@@ -387,7 +388,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
387
388
external_constants_[n_external_constants_].key = key;
388
389
389
390
// Save the buffer.
390
- Result<FreeableBuffer> buffer = named_data_map ->get_data (key);
391
+ Result<FreeableBuffer> buffer = external_data_map ->get_data (key);
391
392
ET_CHECK_OR_RETURN_ERROR (
392
393
buffer.ok (),
393
394
InvalidExternalData,
@@ -400,7 +401,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
400
401
return Error::Ok;
401
402
}
402
403
403
- Error Method::parse_values (const NamedDataMap* named_data_map ) {
404
+ Error Method::parse_values (const NamedDataMap* external_data_map ) {
404
405
auto flatbuffer_values = serialization_plan_->values ();
405
406
ET_CHECK_OR_RETURN_ERROR (
406
407
flatbuffer_values != nullptr , InvalidProgram, " Missing values" );
@@ -428,7 +429,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) {
428
429
if (external_constants_ == nullptr ) {
429
430
return Error::MemoryAllocationFailed;
430
431
}
431
- Error err = parse_external_constants (named_data_map );
432
+ Error err = parse_external_constants (external_data_map );
432
433
if (err != Error::Ok) {
433
434
return err;
434
435
}
@@ -541,7 +542,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) {
541
542
program_,
542
543
memory_manager_,
543
544
static_cast <const executorch_flatbuffer::Tensor*>(val),
544
- named_data_map ,
545
+ external_data_map ,
545
546
Span<NamedData>(external_constants_, n_external_constants_));
546
547
if (!t.ok ()) {
547
548
ET_LOG (
@@ -741,7 +742,7 @@ Result<Method> Method::load(
741
742
const Program* program,
742
743
MemoryManager* memory_manager,
743
744
EventTracer* event_tracer,
744
- const NamedDataMap* named_data_map ) {
745
+ const NamedDataMap* external_data_map ) {
745
746
MemoryAllocator* temp_allocator = memory_manager->temp_allocator ();
746
747
if (temp_allocator == nullptr ) {
747
748
PlatformMemoryAllocator* platform_allocator =
@@ -755,7 +756,7 @@ Result<Method> Method::load(
755
756
}
756
757
Method method (program, memory_manager, event_tracer, temp_allocator);
757
758
ET_LOG (Debug, " Loading method: %s." , s_plan->name ()->c_str ());
758
- Error err = method.init (s_plan, named_data_map );
759
+ Error err = method.init (s_plan, external_data_map );
759
760
if (err != Error::Ok) {
760
761
return err;
761
762
} else {
@@ -766,7 +767,7 @@ Result<Method> Method::load(
766
767
767
768
Error Method::init (
768
769
executorch_flatbuffer::ExecutionPlan* s_plan,
769
- const NamedDataMap* named_data_map ) {
770
+ const NamedDataMap* external_data_map ) {
770
771
EXECUTORCH_SCOPE_PROF (" Method::init" );
771
772
internal::EventTracerProfileMethodScope event_tracer_profile_scope =
772
773
internal::EventTracerProfileMethodScope (event_tracer_, " Method::init" );
@@ -783,7 +784,7 @@ Error Method::init(
783
784
784
785
{
785
786
// Parse the elements of the values_ array.
786
- Error err = parse_values (named_data_map );
787
+ Error err = parse_values (external_data_map );
787
788
if (err != Error::Ok) {
788
789
return err;
789
790
}
@@ -800,21 +801,34 @@ Error Method::init(
800
801
return Error::MemoryAllocationFailed;
801
802
}
802
803
803
- // Get NamedDataMap, if it exists.
804
- const NamedDataMap* pte_data_map = nullptr ;
805
- Result<const NamedDataMap*> pte_data_map_res =
806
- program_->get_named_data_map ();
807
- if (pte_data_map_res.ok ()) {
808
- pte_data_map = pte_data_map_res.get ();
809
- }
810
-
804
+ // Get PTE data map, if it exists.
805
+ auto pte_data_map = program_->get_named_data_map ();
811
806
ET_CHECK_OR_RETURN_ERROR (
812
- !(pte_data_map && named_data_map),
813
- NotSupported,
814
- " NamedDataMap merge not supported; both pte_data_map and named_data_map are non-empty. If you see this error please file an issue at https://github.com/pytorch/executorch/issues" );
815
-
816
- if (!named_data_map || named_data_map->get_num_keys ().get () == 0 ) {
817
- named_data_map = pte_data_map;
807
+ pte_data_map.ok () || pte_data_map.error () == Error::NotFound,
808
+ InvalidProgram,
809
+ " Failed to get named data map from program: 0x%" PRIx32,
810
+ static_cast <uint32_t >(pte_data_map.error ()));
811
+
812
+ const NamedDataMap* named_data_map = nullptr ;
813
+ if (external_data_map && pte_data_map.ok ()) {
814
+ // Merge external_data_map and pte_data_map if both are present.
815
+ auto merged =
816
+ internal::MergedDataMap::load (external_data_map, pte_data_map.get ());
817
+ if (!merged.ok ()) {
818
+ return merged.error ();
819
+ }
820
+ // Allocate memory for the merged data map.
821
+ merged_data_map_ =
822
+ method_allocator->allocateInstance <internal::MergedDataMap>();
823
+ if (merged_data_map_ == nullptr ) {
824
+ return Error::MemoryAllocationFailed;
825
+ }
826
+ new (merged_data_map_) internal::MergedDataMap (std::move (merged.get ()));
827
+ named_data_map = merged_data_map_;
828
+ } else if (external_data_map) {
829
+ named_data_map = external_data_map;
830
+ } else if (pte_data_map.ok ()) {
831
+ named_data_map = pte_data_map.get ();
818
832
}
819
833
820
834
// n_delegate_ counts the number of successfully-initialized delegates for
@@ -1680,6 +1694,10 @@ Method::~Method() {
1680
1694
for (const auto i : c10::irange (n_external_constants_)) {
1681
1695
external_constants_[i].buffer .~FreeableBuffer ();
1682
1696
}
1697
+ // Free the MergedDataMap.
1698
+ if (merged_data_map_ != nullptr ) {
1699
+ merged_data_map_->~MergedDataMap ();
1700
+ }
1683
1701
// All other fields are trivially destructible.
1684
1702
}
1685
1703
} // namespace ET_RUNTIME_NAMESPACE
0 commit comments