From facd918049019ca599735c0df79b163ecd2a8e3f Mon Sep 17 00:00:00 2001
From: Adam Reichold
Date: Sat, 31 May 2025 22:03:36 +0200
Subject: [PATCH] Avoid one atomic access per work item by stopping workers
 from the inside out.

---
 fork_union.rs | 35 +++++++++++++++++++----------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/fork_union.rs b/fork_union.rs
index 4037782..33a53be 100644
--- a/fork_union.rs
+++ b/fork_union.rs
@@ -14,7 +14,7 @@ use std::collections::TryReserveError;
 use std::fmt;
 use std::io::Error as IoError;
 use std::ptr;
-use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::thread::{self, JoinHandle};
 
 /// Pads the wrapped value to 128 bytes to avoid false sharing.
@@ -70,13 +70,17 @@ impl std::error::Error for Error {
     }
 }
 
-type Trampoline = unsafe fn(*const (), usize);
+type Trampoline = unsafe fn(*const (), usize, &mut bool);
 
 /// Dummy trampoline function as opposed to the real `worker_loop`.
-unsafe fn dummy_trampoline(_ctx: *const (), _index: usize) {
+unsafe fn dummy_trampoline(_ctx: *const (), _index: usize, _stop: &mut bool) {
     unreachable!("dummy_trampoline should not be called")
 }
 
+unsafe fn stop_trampoline(_ctx: *const (), _index: usize, stop: &mut bool) {
+    *stop = true;
+}
+
 /// The shared state of the thread pool, used by all threads.
 /// It intentionally pads all of independently mutable regions to avoid false sharing.
 /// The `task_trampoline` function receives the `task_context` state pointers and
@@ -84,7 +88,6 @@ unsafe fn dummy_trampoline(_ctx: *const (), _index: usize) {
 #[repr(align(128))]
 struct Inner {
     pub total_threads: usize,
-    pub stop: Padded<AtomicBool>,
 
     pub fork_context: *const (),
     pub fork_trampoline: Trampoline,
@@ -99,7 +102,6 @@ impl Inner {
     pub fn new(threads: usize) -> Self {
         Self {
             total_threads: threads,
-            stop: Padded::new(AtomicBool::new(false)),
             fork_context: ptr::null(),
             fork_trampoline: dummy_trampoline,
             threads_to_sync: Padded::new(AtomicUsize::new(0)),
@@ -251,12 +253,12 @@ impl ThreadPool {
             return;
         }
         assert!(self.inner.threads_to_sync.load(Ordering::SeqCst) == 0);
-        self.inner.reset_fork();
-        self.inner.stop.store(true, Ordering::Release);
+        self.inner.fork_context = ptr::null();
+        self.inner.fork_trampoline = stop_trampoline;
+        self.inner.fork_generation.fetch_add(1, Ordering::Release);
         for handle in self.workers.drain(..) {
             let _ = handle.join();
         }
-        self.inner.stop.store(false, Ordering::Relaxed);
     }
 
     /// Executes a function on each thread of the pool.
@@ -310,33 +312,34 @@ where
     }
 }
 
-unsafe fn call_lambda<F: Fn(usize)>(ctx: *const (), index: usize) {
+unsafe fn call_lambda<F: Fn(usize)>(ctx: *const (), index: usize, _stop: &mut bool) {
     let f = &*(ctx as *const F);
     f(index);
 }
 
 fn worker_loop(inner: &'static Inner, thread_index: usize) {
     let mut last_generation = 0usize;
+    let mut stop = false;
     assert!(thread_index != 0);
 
     loop {
         let mut new_generation;
-        let mut wants_stop;
         while {
             new_generation = inner.fork_generation.load(Ordering::Acquire);
-            wants_stop = inner.stop.load(Ordering::Acquire);
-            new_generation == last_generation && !wants_stop
+            new_generation == last_generation
         } {
             thread::yield_now();
        }
 
-        if wants_stop {
-            return;
-        }
         let trampoline = inner.trampoline();
         let context = inner.context();
         unsafe {
-            trampoline(context, thread_index);
+            trampoline(context, thread_index, &mut stop);
        }
+
+        if stop {
+            return;
+        }
+
         last_generation = new_generation;
         let before = inner.threads_to_sync.fetch_sub(1, Ordering::Release);
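
For readers unfamiliar with the pattern, the sketch below shows the "stop from the
inside out" idea in isolation: shutdown is requested by publishing one final job that
tells each worker to exit, so the hot wait loop only watches the fork generation and
never loads a separate stop flag. This is not fork_union code; the `Job` enum,
`Shared` struct, and `worker` function are invented for illustration, and a `Mutex`
stands in for the pool's raw-pointer trampoline.

// Minimal standalone sketch of "stop from the inside out" (not fork_union itself).
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;

// Hypothetical job representation for this sketch only.
enum Job {
    Run(fn(usize)), // run a plain function, passing the worker index
    Stop,           // "poison" job: the worker decides to exit from the inside
}

struct Shared {
    generation: AtomicUsize, // bumped once per published job
    job: Mutex<Job>,         // the currently published job
}

fn worker(shared: Arc<Shared>, index: usize) {
    let mut last_generation = 0usize;
    loop {
        // Spin until a new generation is published; note there is no separate
        // `stop` atomic to load on every iteration.
        let mut new_generation;
        while {
            new_generation = shared.generation.load(Ordering::Acquire);
            new_generation == last_generation
        } {
            thread::yield_now();
        }
        last_generation = new_generation;

        // The published job itself decides whether the worker keeps running.
        match *shared.job.lock().unwrap() {
            Job::Run(f) => f(index),
            Job::Stop => return,
        }
    }
}

fn main() {
    let shared = Arc::new(Shared {
        generation: AtomicUsize::new(0),
        job: Mutex::new(Job::Stop),
    });

    let handles: Vec<_> = (1..4)
        .map(|i| {
            let s = Arc::clone(&shared);
            thread::spawn(move || worker(s, i))
        })
        .collect();

    // Publish a regular job.
    *shared.job.lock().unwrap() = Job::Run(|i| println!("hello from worker {i}"));
    shared.generation.fetch_add(1, Ordering::Release);

    // A real pool would wait for every worker to finish the job here (fork_union
    // counts down `threads_to_sync`); a sleep keeps this sketch short.
    thread::sleep(std::time::Duration::from_millis(50));

    // Shut down by publishing the poison job instead of storing a stop flag.
    *shared.job.lock().unwrap() = Job::Stop;
    shared.generation.fetch_add(1, Ordering::Release);

    for handle in handles {
        let _ = handle.join();
    }
}

In the patch itself the same role is played by `stop_trampoline`: shutdown publishes it
as the final fork and bumps `fork_generation`, each worker sets its local `stop` flag
from inside the trampoline call and returns, and `threads_to_sync` (only crudely
approximated by the sleep above) remains the completion barrier.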