11
11
#include < atomic>
12
12
#include < condition_variable>
13
13
#include < cstdlib>
14
+ #include < forward_list>
14
15
#include < functional>
15
16
#include < future>
16
- #include < iostream >
17
+ #include < iterator >
17
18
#include < mutex>
18
19
#include < numeric>
19
20
#include < queue>
@@ -30,18 +31,14 @@ namespace detail {
30
31
class worker_thread {
31
32
public:
32
33
// Initializes state and launches the worker thread
33
- worker_thread () noexcept : m_isRunning(false ), m_numTasks(0 ) {}
34
-
35
- // Creates and launches the worker thread
36
- inline void start (size_t threadId) {
34
+ worker_thread (size_t threadId) noexcept
35
+ : m_threadId(threadId), m_isRunning(false ), m_numTasks(0 ) {
37
36
std::lock_guard<std::mutex> lock (m_workMutex);
38
37
if (this ->is_running ()) {
39
38
return ;
40
39
}
41
- m_threadId = threadId;
42
40
m_worker = std::thread ([this ]() {
43
41
while (true ) {
44
- // pin the thread to the cpu
45
42
std::unique_lock<std::mutex> lock (m_workMutex);
46
43
// Wait until there's work available
47
44
m_startWorkCondition.wait (
@@ -51,7 +48,7 @@ class worker_thread {
51
48
break ;
52
49
}
53
50
// Retrieve a task from the queue
54
- auto task = m_tasks.front ();
51
+ worker_task_t task = std::move ( m_tasks.front () );
55
52
m_tasks.pop ();
56
53
57
54
// Not modifying internal state anymore, can release the mutex
@@ -63,7 +60,7 @@ class worker_thread {
63
60
}
64
61
});
65
62
66
- m_isRunning = true ;
63
+ m_isRunning. store ( true , std::memory_order_release) ;
67
64
}
68
65
69
66
inline void schedule (const worker_task_t &task) {
@@ -79,16 +76,12 @@ class worker_thread {
79
76
size_t num_pending_tasks () const noexcept {
80
77
// m_numTasks is an atomic counter because we don't want to lock the mutex
81
78
// here, num_pending_tasks is only used for heuristics
82
- return m_numTasks.load ();
79
+ return m_numTasks.load (std::memory_order_acquire );
83
80
}
84
81
85
82
// Waits for all tasks to finish and destroys the worker thread
86
83
inline void stop () {
87
- {
88
- // Notify the worker thread to stop executing
89
- std::lock_guard<std::mutex> lock (m_workMutex);
90
- m_isRunning = false ;
91
- }
84
+ m_isRunning.store (false , std::memory_order_release);
92
85
m_startWorkCondition.notify_all ();
93
86
if (m_worker.joinable ()) {
94
87
// Wait for the worker thread to finish handling the task queue
@@ -97,18 +90,21 @@ class worker_thread {
97
90
}
98
91
99
92
// Checks whether this worker thread is currently running
100
- inline bool is_running () const noexcept { return m_isRunning; }
93
+ inline bool is_running () const noexcept {
94
+ return m_isRunning.load (std::memory_order_acquire);
95
+ }
101
96
102
97
private:
103
98
// Unique ID identifying the thread in the threadpool
104
- size_t m_threadId;
99
+ const size_t m_threadId;
100
+
105
101
std::thread m_worker;
106
102
107
103
std::mutex m_workMutex;
108
104
109
105
std::condition_variable m_startWorkCondition;
110
106
111
- bool m_isRunning;
107
+ std::atomic< bool > m_isRunning;
112
108
113
109
std::queue<worker_task_t > m_tasks;
114
110
@@ -121,47 +117,21 @@ class worker_thread {
121
117
// parameters and futures.
122
118
class simple_thread_pool {
123
119
public:
124
- simple_thread_pool (size_t numThreads = 0 ) noexcept : m_isRunning(false ) {
125
- this ->resize (numThreads);
126
- this ->start ();
127
- }
128
-
129
- ~simple_thread_pool () { this ->stop (); }
130
-
131
- // Creates and launches the worker threads
132
- inline void start () {
133
- if (this ->is_running ()) {
134
- return ;
135
- }
136
- size_t threadId = 0 ;
137
- for (auto &t : m_workers) {
138
- t.start (threadId);
139
- threadId++;
120
+ simple_thread_pool () noexcept
121
+ : m_isRunning(false ), m_numThreads(get_num_threads()) {
122
+ for (size_t i = 0 ; i < m_numThreads; i++) {
123
+ m_workers.emplace_front (i);
140
124
}
141
125
m_isRunning.store (true , std::memory_order_release);
142
126
}
143
127
144
- // Waits for all tasks to finish and destroys the worker threads
145
- inline void stop () {
128
+ ~simple_thread_pool () {
146
129
for (auto &t : m_workers) {
147
130
t.stop ();
148
131
}
149
132
m_isRunning.store (false , std::memory_order_release);
150
133
}
151
134
152
- inline void resize (size_t numThreads) {
153
- char *envVar = std::getenv (" SYCL_NATIVE_CPU_HOST_THREADS" );
154
- if (envVar) {
155
- numThreads = std::stoul (envVar);
156
- }
157
- if (numThreads == 0 ) {
158
- numThreads = std::thread::hardware_concurrency ();
159
- }
160
- if (!this ->is_running () && (numThreads != this ->num_threads ())) {
161
- m_workers = decltype (m_workers)(numThreads);
162
- }
163
- }
164
-
165
135
inline void schedule (const worker_task_t &task) {
166
136
// Schedule the task on the best available worker thread
167
137
this ->best_worker ().schedule (task);
@@ -171,7 +141,7 @@ class simple_thread_pool {
171
141
return m_isRunning.load (std::memory_order_acquire);
172
142
}
173
143
174
- inline size_t num_threads () const noexcept { return m_workers. size () ; }
144
+ inline size_t num_threads () const noexcept { return m_numThreads ; }
175
145
176
146
inline size_t num_pending_tasks () const noexcept {
177
147
return std::accumulate (std::begin (m_workers), std::end (m_workers),
@@ -201,24 +171,32 @@ class simple_thread_pool {
201
171
}
202
172
203
173
private:
204
- std::vector<worker_thread> m_workers;
174
// Determines how many worker threads the pool should create.
//
// The SYCL_NATIVE_CPU_HOST_THREADS environment variable takes precedence when
// it parses to a non-zero integer. If it is unset, unparsable, zero or
// overflows, std::thread::hardware_concurrency() is used instead. The
// standard allows hardware_concurrency() to return 0 when the value is not
// computable, so the result is clamped to at least one thread and the pool
// always ends up with a worker.
//
// Never throws, so it can be called from the noexcept constructor's member
// initializer list without risking std::terminate on a bad environment value.
static size_t get_num_threads() noexcept {
  size_t numThreads = 0;
  if (const char *envVar = std::getenv("SYCL_NATIVE_CPU_HOST_THREADS")) {
    // std::strtoul reports failure through its results instead of throwing
    // the way std::stoul does.
    char *end = nullptr;
    const unsigned long parsed = std::strtoul(envVar, &end, 10);
    const bool consumedDigits = (end != envVar);
    // ~0UL is ULONG_MAX, which is what strtoul returns on overflow.
    if (consumedDigits && parsed != ~0UL) {
      numThreads = static_cast<size_t>(parsed);
    }
  }
  if (numThreads == 0) {
    // Unset, invalid or explicit 0: let the hardware decide.
    numThreads = std::thread::hardware_concurrency();
  }
  return numThreads != 0 ? numThreads : 1;
}
184
+
185
+ std::forward_list<worker_thread> m_workers;
205
186
206
187
std::atomic<bool > m_isRunning;
188
+
189
+ const size_t m_numThreads;
207
190
};
208
191
} // namespace detail
209
192
210
193
template <typename ThreadPoolT> class threadpool_interface {
211
194
ThreadPoolT threadpool;
212
195
213
196
public:
214
- void start () { threadpool.start (); }
215
-
216
- void stop () { threadpool.stop (); }
217
-
218
197
size_t num_threads () const noexcept { return threadpool.num_threads (); }
219
198
220
- threadpool_interface (size_t numThreads) : threadpool(numThreads) {}
221
- threadpool_interface () : threadpool(0 ) {}
199
+ threadpool_interface () : threadpool() {}
222
200
223
201
auto schedule_task (worker_task_t &&task) {
224
202
auto workerTask = std::make_shared<std::packaged_task<void (size_t )>>(
0 commit comments