diff --git a/src/libmodelbox/base/include/modelbox/base/executor.h b/src/libmodelbox/base/include/modelbox/base/executor.h index 84f27253f..b0eaaddea 100644 --- a/src/libmodelbox/base/include/modelbox/base/executor.h +++ b/src/libmodelbox/base/include/modelbox/base/executor.h @@ -33,6 +33,8 @@ class Executor { virtual ~Executor(); + void SetThreadCount(int thread_count); + template auto Run(func &&fun, int32_t priority, ts &&...params) -> std::future::type> { diff --git a/src/libmodelbox/engine/flowunit_data_executor.cc b/src/libmodelbox/engine/flowunit_data_executor.cc index e6633d3e4..2eeed9d0a 100644 --- a/src/libmodelbox/engine/flowunit_data_executor.cc +++ b/src/libmodelbox/engine/flowunit_data_executor.cc @@ -34,6 +34,10 @@ Executor::Executor(int thread_count) { Executor::~Executor() { thread_pool_ = nullptr; } +void Executor::SetThreadCount(int thread_count){ + thread_pool_->SetThreadSize(thread_count); +} + FlowUnitExecContext::FlowUnitExecContext( std::shared_ptr data_ctx) : data_ctx_(std::move(data_ctx)) {} diff --git a/src/libmodelbox/engine/flowunit_manager.cc b/src/libmodelbox/engine/flowunit_manager.cc index a72786ef9..52bc05507 100644 --- a/src/libmodelbox/engine/flowunit_manager.cc +++ b/src/libmodelbox/engine/flowunit_manager.cc @@ -59,6 +59,13 @@ Status FlowUnitManager::Initialize( SetDeviceManager(std::move(device_mgr)); Status status; status = InitFlowUnitFactory(driver); + + if (config != nullptr){ + max_executor_thread_num = config->GetUint32("graph.max_executor_thread_num", 0); + } else { + max_executor_thread_num = 0; + } + if (status != STATUS_SUCCESS) { return status; } @@ -407,6 +414,9 @@ std::shared_ptr FlowUnitManager::CreateSingleFlowUnit( return nullptr; } + MBLOG_INFO << "max_executor_thread_num: " << max_executor_thread_num; + device->GetDeviceExecutor()->SetThreadCount(max_executor_thread_num); + flowunit->SetBindDevice(device); std::vector &in_list = flowunit_desc->GetFlowUnitInput(); for (auto &in_item : in_list) { diff --git a/src/libmodelbox/include/modelbox/flowunit.h b/src/libmodelbox/include/modelbox/flowunit.h index bb8a8726f..846cfa2f7 100644 --- a/src/libmodelbox/include/modelbox/flowunit.h +++ b/src/libmodelbox/include/modelbox/flowunit.h @@ -612,6 +612,8 @@ class FlowUnitManager { std::shared_ptr GetDeviceManager(); + int max_executor_thread_num; + private: Status CheckParams(const std::string &unit_name, const std::string &unit_type, const std::string &unit_device_id);