File tree 7 files changed +55
-9
lines changed
base/include/modelbox/base 7 files changed +55
-9
lines changed Original file line number Diff line number Diff line change @@ -33,6 +33,8 @@ class Executor {
33
33
34
34
virtual ~Executor ();
35
35
36
+ void SetThreadCount (int thread_count);
37
+
36
38
template <typename func, typename ... ts>
37
39
auto Run (func &&fun, int32_t priority, ts &&...params)
38
40
-> std::future<typename std::result_of<func(ts...)>::type> {
Original file line number Diff line number Diff line change @@ -34,6 +34,10 @@ Executor::Executor(int thread_count) {
34
34
35
35
Executor::~Executor () { thread_pool_ = nullptr ; }
36
36
37
+ void Executor::SetThreadCount (int thread_count) {
38
+ thread_pool_->SetThreadSize (thread_count);
39
+ }
40
+
37
41
FlowUnitExecContext::FlowUnitExecContext (
38
42
std::shared_ptr<FlowUnitDataContext> data_ctx)
39
43
: data_ctx_(std::move(data_ctx)) {}
Original file line number Diff line number Diff line change @@ -46,6 +46,8 @@ void FlowUnitGroup::InitTrace() {
46
46
}
47
47
}
48
48
49
+ uint32_t FlowUnitGroup::GetBatchSize () const { return batch_size_; }
50
+
49
51
std::shared_ptr<TraceSlice> FlowUnitGroup::StartTrace (
50
52
FUExecContextList &exec_ctx_list) {
51
53
std::call_once (trace_init_flag_, &FlowUnitGroup::InitTrace, this );
@@ -388,7 +390,7 @@ Status FlowUnitGroup::Open(const CreateExternalDataFunc &create_func) {
388
390
389
391
return STATUS_OK;
390
392
};
391
-
393
+
392
394
ThreadPool pool (std::thread::hardware_concurrency ());
393
395
pool.SetName (unit_name_ + " -Open" );
394
396
std::vector<std::future<Status>> result;
Original file line number Diff line number Diff line change @@ -59,6 +59,14 @@ Status FlowUnitManager::Initialize(
59
59
SetDeviceManager (std::move (device_mgr));
60
60
Status status;
61
61
status = InitFlowUnitFactory (driver);
62
+
63
+ if (config != nullptr ) {
64
+ max_executor_thread_num_ =
65
+ config->GetUint32 (" graph.max_executor_thread_num" , 0 );
66
+ } else {
67
+ max_executor_thread_num_ = 0 ;
68
+ }
69
+
62
70
if (status != STATUS_SUCCESS) {
63
71
return status;
64
72
}
@@ -407,6 +415,12 @@ std::shared_ptr<FlowUnit> FlowUnitManager::CreateSingleFlowUnit(
407
415
return nullptr ;
408
416
}
409
417
418
+ if (max_executor_thread_num_ > 0 ) {
419
+ MBLOG_INFO << " find the parameter max_executor_thread_num in the config: "
420
+ << max_executor_thread_num_;
421
+ device->GetDeviceExecutor ()->SetThreadCount (max_executor_thread_num_);
422
+ }
423
+
410
424
flowunit->SetBindDevice (device);
411
425
std::vector<FlowUnitInput> &in_list = flowunit_desc->GetFlowUnitInput ();
412
426
for (auto &in_item : in_list) {
Original file line number Diff line number Diff line change @@ -763,23 +763,43 @@ void Node::CleanDataContext() {
763
763
764
764
Status Node::Run (RunType type) {
765
765
std::list<std::shared_ptr<FlowUnitDataContext>> data_ctx_list;
766
+ size_t process_count = 0 ;
766
767
auto ret = Recv (type, data_ctx_list);
767
- if (!ret) {
768
- return ret;
769
- }
770
768
771
- ret = Process (data_ctx_list);
772
769
if (!ret) {
773
770
return ret;
774
771
}
775
772
776
- if (!GetOutputNames ().empty ()) {
777
- ret = Send (data_ctx_list);
773
+ std::list<std::shared_ptr<FlowUnitDataContext>> process_ctx_list;
774
+
775
+ auto output_names_is_empty = GetOutputNames ().empty ();
776
+
777
+ for (auto & ctx : data_ctx_list) {
778
+ // process data according to batch size
779
+ process_count++;
780
+ process_ctx_list.push_back (ctx);
781
+
782
+ if (process_ctx_list.size () < flowunit_group_->GetBatchSize ()) {
783
+ if (process_count < data_ctx_list.size ()) {
784
+ continue ;
785
+ }
786
+ }
787
+
788
+ ret = Process (process_ctx_list);
778
789
if (!ret) {
779
790
return ret;
780
791
}
781
- } else {
782
- SetLastError (data_ctx_list);
792
+
793
+ if (!output_names_is_empty) {
794
+ ret = Send (process_ctx_list);
795
+ if (!ret) {
796
+ return ret;
797
+ }
798
+ } else {
799
+ SetLastError (process_ctx_list);
800
+ }
801
+
802
+ process_ctx_list.clear ();
783
803
}
784
804
785
805
Clean (data_ctx_list);
Original file line number Diff line number Diff line change @@ -612,6 +612,8 @@ class FlowUnitManager {
612
612
613
613
std::shared_ptr<DeviceManager> GetDeviceManager ();
614
614
615
+ int max_executor_thread_num_;
616
+
615
617
private:
616
618
Status CheckParams (const std::string &unit_name, const std::string &unit_type,
617
619
const std::string &unit_device_id);
Original file line number Diff line number Diff line change @@ -64,6 +64,8 @@ class FlowUnitGroup {
64
64
65
65
Status Close ();
66
66
67
+ uint32_t GetBatchSize () const ;
68
+
67
69
private:
68
70
std::weak_ptr<Node> node_;
69
71
uint32_t batch_size_;
You can’t perform that action at this time.
0 commit comments