Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion exir/backend/test/demos/rpc/ExecutorBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class ExecutorBackend final : public ::executorch::runtime::BackendInterface {
new (client_memory_manager)
MemoryManager(client_method_allocator, client_planned_memory);

const NamedDataMap* named_data_map = context.get_named_data_map();
NamedDataMap* named_data_map = context.get_named_data_map();
// Construct the client Method
Result<Method> method_res = client_program->load_method(
"forward",
Expand Down
6 changes: 3 additions & 3 deletions runtime/backend/backend_init_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BackendInitContext final {
MemoryAllocator* runtime_allocator,
EventTracer* event_tracer = nullptr,
const char* method_name = nullptr,
const NamedDataMap* named_data_map = nullptr)
NamedDataMap* named_data_map = nullptr)
: runtime_allocator_(runtime_allocator),
#ifdef ET_EVENT_TRACER_ENABLED
event_tracer_(event_tracer),
Expand Down Expand Up @@ -65,15 +65,15 @@ class BackendInitContext final {
/** Get the named data map from ExecuTorch runtime.
* This provides a way for backends to retrieve data blobs by key.
*/
const NamedDataMap* get_named_data_map() const {
NamedDataMap* get_named_data_map() const {
return named_data_map_;
}

private:
MemoryAllocator* runtime_allocator_ = nullptr;
EventTracer* event_tracer_ = nullptr;
const char* method_name_ = nullptr;
const NamedDataMap* named_data_map_ = nullptr;
NamedDataMap* named_data_map_ = nullptr;
};

} // namespace ET_RUNTIME_NAMESPACE
Expand Down
36 changes: 19 additions & 17 deletions runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ Result<Method> Method::load(
const Program* program,
MemoryManager* memory_manager,
EventTracer* event_tracer,
const NamedDataMap* named_data_map) {
NamedDataMap* named_data_map) {
MemoryAllocator* temp_allocator = memory_manager->temp_allocator();
if (temp_allocator == nullptr) {
PlatformMemoryAllocator* platform_allocator =
Expand All @@ -766,7 +766,7 @@ Result<Method> Method::load(

Error Method::init(
executorch_flatbuffer::ExecutionPlan* s_plan,
const NamedDataMap* named_data_map) {
NamedDataMap* named_data_map) {
EXECUTORCH_SCOPE_PROF("Method::init");
internal::EventTracerProfileMethodScope event_tracer_profile_scope =
internal::EventTracerProfileMethodScope(event_tracer_, "Method::init");
Expand Down Expand Up @@ -800,21 +800,23 @@ Error Method::init(
return Error::MemoryAllocationFailed;
}

// Get NamedDataMap, if it exists.
const NamedDataMap* pte_data_map = nullptr;
Result<const NamedDataMap*> pte_data_map_res =
program_->get_named_data_map();
if (pte_data_map_res.ok()) {
pte_data_map = pte_data_map_res.get();
}

ET_CHECK_OR_RETURN_ERROR(
!(pte_data_map && named_data_map),
NotSupported,
"NamedDataMap merge not supported; both pte_data_map and named_data_map are non-empty. If you see this error please file an issue at https://github.com/pytorch/executorch/issues");

if (!named_data_map || named_data_map->get_num_keys().get() == 0) {
named_data_map = pte_data_map;
// Resolve NamedDataMaps.
auto pte_data_map = program_->get_named_data_map();
if (pte_data_map.ok()) {
if (named_data_map != nullptr) {
Error error = named_data_map->merge(pte_data_map.get());
ET_CHECK_OR_RETURN_ERROR(
error == Error::Ok,
InvalidExternalData,
"Failed to merge named_data_map with pte_data_map.");
} else {
named_data_map = const_cast<NamedDataMap*>(pte_data_map.get());
}
} else if (pte_data_map.error() != Error::NotFound) {
// Error::NotFound is expected if the program does not have shared data.
// In this case, expect pte_data_map to be empty/null, and we can proceed
// with the named data map only.
return pte_data_map.error();
}

// n_delegate_ counts the number of successfully-initialized delegates for
Expand Down
4 changes: 2 additions & 2 deletions runtime/executor/method.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ class Method final {
const Program* program,
MemoryManager* memory_manager,
EventTracer* event_tracer,
const NamedDataMap* named_data_map);
NamedDataMap* named_data_map);

/**
* Initialize the method from its serialized representation.
Expand All @@ -333,7 +333,7 @@ class Method final {
*/
ET_NODISCARD Error init(
executorch_flatbuffer::ExecutionPlan* s_plan,
const NamedDataMap* named_data_map);
NamedDataMap* named_data_map);

/// Returns true if the Method was successfully initialized.
inline bool initialized() const {
Expand Down
6 changes: 3 additions & 3 deletions runtime/executor/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ Result<Method> Program::load_method(
const char* method_name,
MemoryManager* memory_manager,
EventTracer* event_tracer,
const NamedDataMap* named_data_map) const {
NamedDataMap* named_data_map) const {
EXECUTORCH_SCOPE_PROF("Program::load_method");
internal::event_tracer_create_event_block(event_tracer, "Default");
internal::EventTracerProfileMethodScope event_tracer_scope =
Expand Down Expand Up @@ -372,9 +372,9 @@ Result<const void*> Program::get_constant_buffer_data(
}
}

Result<const NamedDataMap*> Program::get_named_data_map() const {
Result<NamedDataMap*> Program::get_named_data_map() const {
if (pte_data_map_.has_value()) {
return &pte_data_map_.value();
return const_cast<internal::PteDataMap*>(&pte_data_map_.value());
}
return Error::NotFound;
}
Expand Down
4 changes: 2 additions & 2 deletions runtime/executor/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class Program final {
* Get the named data map from the program.
* @return The named data map.
*/
Result<const NamedDataMap*> get_named_data_map() const;
Result<NamedDataMap*> get_named_data_map() const;

/**
* Returns the number of methods in the program.
Expand Down Expand Up @@ -148,7 +148,7 @@ class Program final {
const char* method_name,
MemoryManager* memory_manager,
EventTracer* event_tracer = nullptr,
const NamedDataMap* named_data_map = nullptr) const;
NamedDataMap* named_data_map = nullptr) const;

/**
* Gathers metadata for the named method.
Expand Down
2 changes: 1 addition & 1 deletion runtime/executor/test/program_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ TEST_F(ProgramTest, GetNamedDataMap_Fail) {

// Get the named data map. Expect to fail, as add.pte does not have any
// named data segments.
Result<const executorch::runtime::NamedDataMap*> named_data_map =
Result<executorch::runtime::NamedDataMap*> named_data_map =
program->get_named_data_map();
EXPECT_EQ(named_data_map.error(), Error::NotFound);
}
Expand Down
Loading