Commit facd918 (1 parent: 7658145)

Avoid one atomic access per work item by stopping workers from the inside out.


fork_union.rs

Lines changed: 19 additions & 16 deletions
@@ -14,7 +14,7 @@ use std::collections::TryReserveError;
 use std::fmt;
 use std::io::Error as IoError;
 use std::ptr;
-use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::thread::{self, JoinHandle};
 
 /// Pads the wrapped value to 128 bytes to avoid false sharing.
@@ -70,21 +70,24 @@ impl std::error::Error for Error {
     }
 }
 
-type Trampoline = unsafe fn(*const (), usize);
+type Trampoline = unsafe fn(*const (), usize, &mut bool);
 
 /// Dummy trampoline function as opposed to the real `worker_loop`.
-unsafe fn dummy_trampoline(_ctx: *const (), _index: usize) {
+unsafe fn dummy_trampoline(_ctx: *const (), _index: usize, _stop: &mut bool) {
     unreachable!("dummy_trampoline should not be called")
 }
 
+unsafe fn stop_trampoline(_ctx: *const (), _index: usize, stop: &mut bool) {
+    *stop = true;
+}
+
 /// The shared state of the thread pool, used by all threads.
 /// It intentionally pads all of independently mutable regions to avoid false sharing.
 /// The `task_trampoline` function receives the `task_context` state pointers and
 /// some ethereal thread index similar to C-style thread pools.
 #[repr(align(128))]
 struct Inner {
     pub total_threads: usize,
-    pub stop: Padded<AtomicBool>,
 
     pub fork_context: *const (),
     pub fork_trampoline: Trampoline,
@@ -99,7 +102,6 @@ impl Inner {
     pub fn new(threads: usize) -> Self {
         Self {
             total_threads: threads,
-            stop: Padded::new(AtomicBool::new(false)),
             fork_context: ptr::null(),
             fork_trampoline: dummy_trampoline,
             threads_to_sync: Padded::new(AtomicUsize::new(0)),
@@ -251,12 +253,12 @@ impl<A: Allocator + Clone> ThreadPool<A> {
             return;
         }
         assert!(self.inner.threads_to_sync.load(Ordering::SeqCst) == 0);
-        self.inner.reset_fork();
-        self.inner.stop.store(true, Ordering::Release);
+        self.inner.fork_context = ptr::null();
+        self.inner.fork_trampoline = stop_trampoline;
+        self.inner.fork_generation.fetch_add(1, Ordering::Release);
         for handle in self.workers.drain(..) {
             let _ = handle.join();
         }
-        self.inner.stop.store(false, Ordering::Relaxed);
     }
 
     /// Executes a function on each thread of the pool.
@@ -310,33 +312,34 @@ where
     }
 }
 
-unsafe fn call_lambda<F: Fn(usize)>(ctx: *const (), index: usize) {
+unsafe fn call_lambda<F: Fn(usize)>(ctx: *const (), index: usize, _stop: &mut bool) {
     let f = &*(ctx as *const F);
     f(index);
 }
 
 fn worker_loop(inner: &'static Inner, thread_index: usize) {
     let mut last_generation = 0usize;
+    let mut stop = false;
     assert!(thread_index != 0);
     loop {
         let mut new_generation;
-        let mut wants_stop;
         while {
             new_generation = inner.fork_generation.load(Ordering::Acquire);
-            wants_stop = inner.stop.load(Ordering::Acquire);
-            new_generation == last_generation && !wants_stop
+            new_generation == last_generation
         } {
             thread::yield_now();
        }
-        if wants_stop {
-            return;
-        }
 
         let trampoline = inner.trampoline();
        let context = inner.context();
        unsafe {
-            trampoline(context, thread_index);
+            trampoline(context, thread_index, &mut stop);
        }
+
+        if stop {
+            return;
+        }
+
        last_generation = new_generation;
 
        let before = inner.threads_to_sync.fetch_sub(1, Ordering::Release);
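
To make the "inside out" shutdown concrete, here is a minimal, self-contained sketch of the pattern, not the crate's actual API: instead of every worker polling a separate stop: AtomicBool next to the fork generation, shutdown is delivered as one more broadcast task whose trampoline flips a stack-local flag, so the steady-state wait touches a single atomic. The Shared, broadcast, run_task, and stop_task names below are illustrative, and a Mutex stands in for the raw fork_context/fork_trampoline pointers that fork_union.rs publishes directly.

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;

// Mirrors the new trampoline shape: the `&mut bool` lets a broadcast task
// ask the calling worker to leave its loop.
type Trampoline = fn(usize, usize, &mut bool);

fn run_task(ctx: usize, thread_index: usize, _stop: &mut bool) {
    println!("thread {thread_index} handled context {ctx}");
}

fn stop_task(_ctx: usize, _thread_index: usize, stop: &mut bool) {
    *stop = true; // the only way a worker learns it should exit
}

struct Shared {
    generation: AtomicUsize,          // bumped once per broadcast fork
    task: Mutex<(Trampoline, usize)>, // (trampoline, context) for the current fork
    done: AtomicUsize,                // workers that finished the current fork
}

fn worker_loop(shared: Arc<Shared>, thread_index: usize) {
    let mut last_generation = 0usize;
    let mut stop = false;
    loop {
        // Hot wait: one acquire load per spin, no separate stop flag to poll.
        let mut new_generation = shared.generation.load(Ordering::Acquire);
        while new_generation == last_generation {
            thread::yield_now();
            new_generation = shared.generation.load(Ordering::Acquire);
        }

        let (trampoline, ctx) = *shared.task.lock().unwrap();
        trampoline(ctx, thread_index, &mut stop);
        shared.done.fetch_add(1, Ordering::Release);

        if stop {
            return; // shutdown arrived as just another broadcast task
        }
        last_generation = new_generation;
    }
}

fn broadcast(shared: &Shared, workers: usize, trampoline: Trampoline, ctx: usize) {
    shared.done.store(0, Ordering::Relaxed);
    *shared.task.lock().unwrap() = (trampoline, ctx);
    shared.generation.fetch_add(1, Ordering::Release);
    while shared.done.load(Ordering::Acquire) < workers {
        thread::yield_now();
    }
}

fn main() {
    let shared = Arc::new(Shared {
        generation: AtomicUsize::new(0),
        task: Mutex::new((run_task as Trampoline, 0)),
        done: AtomicUsize::new(0),
    });
    let handles: Vec<_> = (1..4)
        .map(|index| {
            let shared = Arc::clone(&shared);
            thread::spawn(move || worker_loop(shared, index))
        })
        .collect();

    broadcast(&shared, handles.len(), run_task, 42);
    broadcast(&shared, handles.len(), stop_task, 0); // stop without an extra AtomicBool
    for handle in handles {
        let _ = handle.join();
    }
}

The last two broadcast calls mirror the patched stop routine in the diff: publish the stop trampoline, bump the fork generation, then join the workers.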
