diff --git a/runtime/backend/interface.cpp b/runtime/backend/interface.cpp index ffeb133fbf..8af4ed0206 100644 --- a/runtime/backend/interface.cpp +++ b/runtime/backend/interface.cpp @@ -66,5 +66,42 @@ Result get_backend_name(size_t index) { return registered_backends[index].name; } +Error set_option( + const char* backend_name, + const executorch::runtime::Span + backend_options) { + auto backend_class = get_backend_class(backend_name); + if (!backend_class) { + return Error::NotFound; + } + + BackendOptionContext backend_option_context; + Error result = + backend_class->set_option(backend_option_context, backend_options); + if (result != Error::Ok) { + return result; + } + return Error::Ok; +} + + Error get_option( + const char* backend_name, + executorch::runtime::Span + backend_options) { + auto backend_class = get_backend_class(backend_name); + if (!backend_class) { + return Error::NotFound; + } + BackendOptionContext backend_option_context; + executorch::runtime::Span backend_options_ref( + backend_options.data(), backend_options.size()); + auto result = + backend_class->get_option(backend_option_context, backend_options_ref); + if (result != Error::Ok) { + return result; + } + return Error::Ok; +} + } // namespace ET_RUNTIME_NAMESPACE } // namespace executorch diff --git a/runtime/backend/interface.h b/runtime/backend/interface.h index e6a4c2fb8e..382b9e6bd1 100644 --- a/runtime/backend/interface.h +++ b/runtime/backend/interface.h @@ -183,6 +183,37 @@ size_t get_num_registered_backends(); */ Result get_backend_name(size_t index); + +/** +* Sets backend options for a specific backend. +* +* @param backend_name The name of the backend to set options for +* @param backend_options The backend option list containing the options +* to set +* @return Error::Ok on success, Error::NotFound if backend is not found, or +* other error codes on failure +*/ +Error set_option( + const char* backend_name, + const executorch::runtime::Span + backend_options); + + +/** + * Retrieves backend options for a specific backend. + * + * @param backend_name The name of the backend to get options from + * @param backend_options The backend option objects that will be filled with + * the populated values from the backend + * @return Error::Ok on success, Error::NotFound if backend is not found, or + * other error codes on failure + */ + Error get_option( + const char* backend_name, + executorch::runtime::Span + backend_options); + + } // namespace ET_RUNTIME_NAMESPACE } // namespace executorch diff --git a/runtime/backend/test/backend_interface_update_test.cpp b/runtime/backend/test/backend_interface_update_test.cpp index 27dc284af5..021fbd8d81 100644 --- a/runtime/backend/test/backend_interface_update_test.cpp +++ b/runtime/backend/test/backend_interface_update_test.cpp @@ -7,9 +7,11 @@ */ #include +#include #include #include +#include using namespace ::testing; using executorch::runtime::ArrayRef; @@ -61,7 +63,8 @@ class MockBackend : public BackendInterface { int success_update = 0; for (const auto& backend_option : backend_options) { if (strcmp(backend_option.key, "Backend") == 0) { - if (std::holds_alternative>( + if (std::holds_alternative< + std::array>( backend_option.value)) { // Store the value in our member variable const auto& arr = @@ -285,3 +288,116 @@ TEST_F(BackendInterfaceUpdateTest, UpdateBetweenExecutes) { ASSERT_TRUE(mock_backend->target_backend.has_value()); EXPECT_STREQ(mock_backend->target_backend.value().c_str(), "NPU"); } + +// Mock backend for testing +class StubBackend : public BackendInterface { + public: + ~StubBackend() override = default; + + bool is_available() const override { + return true; + } + + Result init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const override { + return nullptr; + } + + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle, + EValue** args) const override { + return Error::Ok; + } + + Error get_option( + BackendOptionContext& context, + executorch::runtime::Span& + backend_options) override { + // For testing purposes, just record that get_option was called + // and verify the input parameters + get_option_called = true; + get_option_call_count++; + last_get_option_size = backend_options.size(); + + // Verify that the expected option key is present and modify the value + for (size_t i = 0; i < backend_options.size(); ++i) { + if (strcmp(backend_options[i].key, "NumberOfThreads") == 0) { + // Set the value to what was stored by set_option + backend_options[i].value = last_num_threads; + found_expected_key = true; + break; + } + } + + return Error::Ok; + } + + Error set_option( + BackendOptionContext& context, + const executorch::runtime::Span& + backend_options) override { + // Store the options for verification + last_options_size = backend_options.size(); + if (backend_options.size() > 0) { + for (const auto& option : backend_options) { + if (strcmp(option.key, "NumberOfThreads") == 0) { + if (auto* val = std::get_if(&option.value)) { + last_num_threads = *val; + } + } + } + } + return Error::Ok; + } + + // Mutable for testing verification + size_t last_options_size = 0; + int last_num_threads = 0; + bool get_option_called = false; + int get_option_call_count = 0; + size_t last_get_option_size = 0; + bool found_expected_key = false; +}; + +class BackendUpdateTest : public ::testing::Test { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + + // Register the stub backend + stub_backend = std::make_unique(); + Backend backend_config{"StubBackend", stub_backend.get()}; + auto register_result = register_backend(backend_config); + ASSERT_EQ(register_result, Error::Ok); + } + + std::unique_ptr stub_backend; +}; + +// Test basic string functionality +TEST_F(BackendUpdateTest, TestSetGetOption) { + BackendOptions<1> backend_options; + int new_num_threads = 4; + backend_options.set_option("NumberOfThreads", new_num_threads); + + auto status = set_option("StubBackend", backend_options.view()); + ASSERT_EQ(status, Error::Ok); + + // Set up the default option, which will be populuated by the get_option call + BackendOption ref_backend_option{"NumberOfThreads", 0}; + status = get_option("StubBackend", ref_backend_option); + + // Verify that the backend actually received the options + ASSERT_TRUE( + std::get(ref_backend_option.value) == + new_num_threads); + + // Verify that the backend actually update the options + ASSERT_EQ(stub_backend->last_options_size, 1); + ASSERT_EQ(stub_backend->last_num_threads, new_num_threads); +}