Skip to content

[7/N] Add sugar syntax for module.update #11534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: gh/cccclai/27/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,5 +317,46 @@ runtime::Error Module::update(
return method->update(backend_options);
}

runtime::Error Module::update(
runtime::ArrayRef<runtime::Entry> backend_options) {
return update("forward", backend_options);
}

runtime::Error Module::update(
const std::string& method_name,
const std::unordered_map<std::string, std::vector<runtime::BackendOption>>&
backend_options) {
std::vector<runtime::Entry> entries;
entries.reserve(backend_options.size());

for (const auto& [backend_name, options] : backend_options) {
entries.push_back(
{backend_name.c_str(),
runtime::ArrayRef<runtime::BackendOption>(
options.data(), options.size())});
}

return update(
method_name,
runtime::ArrayRef<runtime::Entry>(entries.data(), entries.size()));
}

runtime::Error Module::update(
const std::unordered_map<std::string, std::vector<runtime::BackendOption>>&
backend_options) {
std::vector<runtime::Entry> entries;
entries.reserve(backend_options.size());

for (const auto& [backend_name, options] : backend_options) {
entries.push_back(
{backend_name.c_str(),
runtime::ArrayRef<runtime::BackendOption>(
options.data(), options.size())});
}

return update(
runtime::ArrayRef<runtime::Entry>(entries.data(), entries.size()));
}
} // namespace ET_MODULE_NAMESPACE
} // namespace extension
} // namespace executorch
41 changes: 37 additions & 4 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <unordered_set>
#include <vector>

#include <executorch/runtime/backend/backend_options.h>
#include <executorch/runtime/backend/backend_options_map.h>
#include <executorch/runtime/executor/program.h>

#ifdef USE_ATEN_LIB
Expand Down Expand Up @@ -487,10 +489,41 @@ class Module {
*
* @returns An Error to indicate success or failure.
*/
ET_EXPERIMENTAL ET_NODISCARD inline runtime::Error update(
runtime::ArrayRef<runtime::Entry> backend_options) {
return update("forward", backend_options);
}
ET_EXPERIMENTAL ET_NODISCARD runtime::Error update(
runtime::ArrayRef<runtime::Entry> backend_options);

/**
* EXPERIMENTAL: Updates backend options for a specific method.
* Loads the program and method before updating if needed. It uses simple
* std library like unordered_map to store backend options.
*
* @param[in] method_name The name of the method to update.
* @param[in] backend_options A map of <backend_name,
* vector<backend_options>>.
*
* @returns An Error to indicate success or failure.
*/
ET_EXPERIMENTAL ET_NODISCARD runtime::Error update(
const std::string& method_name,
const std::unordered_map<
std::string,
std::vector<runtime::BackendOption>>& backend_options);

/**
* EXPERIMENTAL: Updates backend options for a specific method.
* Loads the program and method before updating if needed. It uses simple
* std library like unordered_map to store backend options.
*
* @param[in] method_name The name of the method to update.
* @param[in] backend_options A map of <backend_name,
* vector<backend_options>>.
*
* @returns An Error to indicate success or failure.
*/
ET_EXPERIMENTAL ET_NODISCARD runtime::Error update(
const std::unordered_map<
std::string,
std::vector<runtime::BackendOption>>& backend_options);

/**
* Retrieves the EventTracer instance being used by the Module.
Expand Down
21 changes: 17 additions & 4 deletions extension/module/test/module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/backend/backend_options.h>
#include <executorch/runtime/backend/backend_options_map.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/executor/test/stub_backend.h>

using namespace ::executorch::extension;
Expand All @@ -33,7 +33,7 @@ class ModuleTest : public ::testing::Test {
add_mul_path_ = std::getenv("ET_MODULE_ADD_MUL_PROGRAM_PATH");
add_mul_data_path_ = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
stub_model_path_ = std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH");

// Register the StubBackend for testing
StubBackend::register_singleton();
}
Expand Down Expand Up @@ -492,7 +492,6 @@ TEST_F(ModuleTest, TestUpdate) {
EXPECT_EQ(update_result, Error::Ok);

ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads);

}

TEST_F(ModuleTest, TestUpdateNonExistentMethod) {
Expand All @@ -503,8 +502,22 @@ TEST_F(ModuleTest, TestUpdateNonExistentMethod) {
int new_num_threads = 4;
backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads);
map.add("StubBackend", backend_options.view());

// Test update method with non-existent method name
const auto update_result = module.update("nonexistent", map.entries());
EXPECT_NE(update_result, Error::Ok);
}

TEST_F(ModuleTest, TestUpdateSugarSyntax) {
Module module(stub_model_path_);
int new_num_threads = 4;

// Using std::unordered_map and std::vector directly
std::unordered_map<std::string, std::vector<BackendOption>> backend_options =
{{"StubBackend", {{"NumberOfThreads", new_num_threads}}}};

const auto update_result = module.update("forward", backend_options);

EXPECT_EQ(update_result, Error::Ok);
ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads);
}
Loading