From 082d7fbcab5e7aab5a57b02a638d44c47665dda1 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Tue, 10 Jun 2025 11:39:13 -0700 Subject: [PATCH] [7/N] Add sugar syntax for module.update The update API in method is supposed to be portable, but we can make it more user friendly for the update API in module. Add a bit sugar syntax in module to improve UX. Such that user can update backend option in module like following: ``` Module module(stub_model_path_); int new_num_threads = 4; const auto update_result = module.update("forward", { {"StubBackend", {{IntKey("NumberOfThreads"), new_num_threads}} }, ); ``` Differential Revision: [D76242292](https://our.internmc.facebook.com/intern/diff/D76242292/) [ghstack-poisoned] --- .../module/dynamic_backend_options_map.h | 43 +++++++++++++++++++ extension/module/module.h | 7 +++ extension/module/targets.bzl | 1 + extension/module/test/module_test.cpp | 15 +++++++ 4 files changed, 66 insertions(+) create mode 100644 extension/module/dynamic_backend_options_map.h diff --git a/extension/module/dynamic_backend_options_map.h b/extension/module/dynamic_backend_options_map.h new file mode 100644 index 00000000000..7297d6ca01f --- /dev/null +++ b/extension/module/dynamic_backend_options_map.h @@ -0,0 +1,43 @@ + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace runtime { + +class DynamicBackendOptionsMap { + public: + using OptionList = std::initializer_list; + + DynamicBackendOptionsMap( + std::initializer_list> list) { + entries_.reserve(list.size()); + for (const auto& item : list) { + // Store backend name + backend_names_.push_back(item.first); + // Store options + options_storage_.push_back(std::vector(item.second)); + // Create Entry with stable references + entries_.push_back({ + backend_names_.back().c_str(), + ArrayRef(options_storage_.back().data(), options_storage_.back().size()) + }); + } + } + + ArrayRef entries() const { + return ArrayRef(entries_.data(), entries_.size()); + } + + private: + std::vector backend_names_; + std::vector> options_storage_; + std::vector entries_; +}; + +} // namespace runtime +} // namespace executorch diff --git a/extension/module/module.h b/extension/module/module.h index dbc4e692636..86558fe890a 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -15,6 +15,7 @@ #include #include +#include namespace executorch { namespace extension { @@ -483,6 +484,12 @@ class Module { return update("forward", backend_options); } + ET_EXPERIMENTAL ET_NODISCARD inline runtime::Error update( + const std::string& method_name, + const runtime::DynamicBackendOptionsMap& backend_options) { + return update(method_name, backend_options.entries()); + } + /** * Retrieves the EventTracer instance being used by the Module. * EventTracer is used for tracking and logging events during the execution diff --git a/extension/module/targets.bzl b/extension/module/targets.bzl index 3e449da5e14..8f413b5dca6 100644 --- a/extension/module/targets.bzl +++ b/extension/module/targets.bzl @@ -17,6 +17,7 @@ def define_common_targets(): ], exported_headers = [ "module.h", + "dynamic_backend_options_map.h", ], visibility = [ "@EXECUTORCH_CLIENTS", diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index 24476c4adab..25fa990aa5a 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -508,3 +508,18 @@ TEST_F(ModuleTest, TestUpdateNonExistentMethod) { const auto update_result = module.update("nonexistent", map.entries()); EXPECT_NE(update_result, Error::Ok); } + +TEST_F(ModuleTest, TestUpdateSugarSyntax) { + Module module(stub_model_path_); + int new_num_threads = 4; + + // Clean sugar syntax + const auto update_result = module.update("forward", + { + {"StubBackend", {{IntKey("NumberOfThreads"), new_num_threads}} + }, + ); + + EXPECT_EQ(update_result, Error::Ok); + ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads); +}