I think at least the implementation could be significantly simplified by using an API based on the standard library's `thread::scope` mechanism:

```rust
use std::cell::Cell;
use std::mem::transmute;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
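
/// Handle passed to the closure given to [`scope`]; used to broadcast
/// work items to every thread in the scope.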
pub struct Scope<'scope> {
    state: &'scope State,
}

impl Scope<'_> {
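    /// Runs `f` on every thread in the scope, passing each call its thread
    /// index, and returns once all of them have finished.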
    pub fn broadcast<F>(&self, f: F)
    where
        F: Fn(usize) + Sync,
    {
        let state = self.state;

        // Bind the erased work item to a local so that it outlives this
        // statement and remains valid until `_guard` has reset `work`.
        let work: &Work = &|index, _stop| f(index);

        // SAFETY: `_guard` will reset `state.work` before this function returns,
        // but only after all pending workers are finished.
        unsafe {
            state.work.set(transmute::<&Work, &'static Work>(work));
        }
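
        // Publish how many workers must check in, then bump the generation
        // (Release) so that the spinning workers observe the new work item.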
        state.pending.store(self.state.workers, Ordering::Relaxed);
        state.generation.fetch_add(1, Ordering::Release);
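
        // Waits for all workers to finish the current item and then resets
        // `work` to STOP; doing this in `drop` keeps `broadcast` unwind-safe.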
        struct ResetGuard<'scope>(&'scope State);

        impl Drop for ResetGuard<'_> {
            fn drop(&mut self) {
                let state = self.0;

                while state.pending.load(Ordering::Acquire) != 0 {
                    thread::yield_now();
                }

                state.work.set(STOP);
            }
        }

        let _guard = ResetGuard(state);

        f(0);
    }
}
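
/// Spawns `parallelism - 1` worker threads and runs `f` on the current
/// thread, handing it a [`Scope`] that can broadcast work to all of them.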
pub fn scope<F, R>(parallelism: usize, f: F) -> R
where
    F: for<'scope> FnOnce(Scope<'scope>) -> R,
{
    assert!(parallelism > 1);

    let state = &State {
        work: Cell::new(STOP),
        workers: parallelism - 1,
        pending: Padded(AtomicUsize::new(0)),
        generation: Padded(AtomicUsize::new(0)),
    };

    thread::scope(|scope| {
        for index in 1..parallelism {
            scope.spawn(move || state.worker(index));
        }
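
        // Broadcasts the STOP sentinel in `drop`, so that the workers exit
        // their loops and get joined by `thread::scope` even if `f` unwinds.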
        struct StopGuard<'scope>(&'scope State);

        impl Drop for StopGuard<'_> {
            fn drop(&mut self) {
                let state = self.0;
                state.work.set(STOP);
                state.generation.fetch_add(1, Ordering::Release);
            }
        }

        let _guard = StopGuard(state);

        f(Scope { state })
    })
}
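
// Coordination state shared between the broadcasting thread and the workers;
// the 128-byte alignment together with `Padded` keeps the hot atomics on
// separate cache lines.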
#[repr(align(128))]
struct State {
    work: Cell<&'static Work<'static>>,
    workers: usize,
    pending: Padded<AtomicUsize>,
    generation: Padded<AtomicUsize>,
}
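
// SAFETY: `work` is only written while no worker is reading it; the
// generation/pending protocol establishes the required happens-before edges.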
unsafe impl Send for State {}
unsafe impl Sync for State {}

impl State {
    fn worker(&self, index: usize) {
        let mut last_generation = 0;

        loop {
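            // Spin until the coordinator publishes a new generation.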
            loop {
                let curr_generation = self.generation.load(Ordering::Acquire);

                if last_generation != curr_generation {
                    last_generation = curr_generation;
                    break;
                } else {
                    thread::yield_now();
                }
            }
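
            // Run the freshly published work item; the STOP sentinel asks
            // this worker to shut down by setting `stop`.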
            let mut stop = false;
            self.work.get()(index, &mut stop);

            if stop {
                return;
            }
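
            // Check in; this Release pairs with the Acquire in `ResetGuard`.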
            self.pending.fetch_sub(1, Ordering::Release);
        }
    }
}
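
// A work item receives the worker's index and a stop flag; STOP is the
// sentinel used to shut the workers down.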
type Work<'work> = dyn Fn(usize, &mut bool) + 'work;
const STOP: &Work = &|_index, stop: &mut bool| *stop = true;
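
// Pads its contents out to a full cache line to avoid false sharing.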
#[repr(align(128))]
struct Padded<T>(T);

impl<T> Deref for Padded<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::atomic::AtomicBool;

    #[test]
    fn it_works() {
        let parallelism = 4;

        let flags = (0..parallelism)
            .map(|_| AtomicBool::new(false))
            .collect::<Vec<_>>();

        scope(parallelism, |scope| {
            scope.broadcast(|index| {
                flags[index].store(true, Ordering::Relaxed);
            });
        });

        for flag in flags {
            assert!(flag.load(Ordering::Relaxed));
        }
    }
}
```

There is some other stuff in there, e.g. avoiding manual trampoline management by relying on compiler-generated vtables, avoiding a second atomic load for the stop flag by broadcasting a STOP work item, and using RAII guards for unwind safety.
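
To make the trampoline point concrete, here is a minimal, hypothetical sketch (the names `Erased`, `trampoline`, and `erase_manually` are mine, not taken from either implementation) of erasing a closure by hand versus letting the compiler generate the vtable through an unsizing coercion to `&dyn Fn`:

```rust
// Manual trampoline: erase the closure's type by pairing a monomorphized
// function pointer with an untyped context pointer.
struct Erased {
    call: unsafe fn(*const (), usize),
    data: *const (),
}

unsafe fn trampoline<F: Fn(usize)>(data: *const (), index: usize) {
    // SAFETY: the caller must ensure `data` points to a live `F`.
    unsafe { (*data.cast::<F>())(index) }
}

fn erase_manually<F: Fn(usize)>(f: &F) -> Erased {
    Erased {
        call: trampoline::<F>,
        data: (f as *const F).cast(),
    }
}

// Vtable-based: the unsizing coercion makes the compiler generate the
// equivalent of `Erased` (data pointer plus vtable) automatically.
fn erase_with_dyn<F: Fn(usize)>(f: &F) -> &dyn Fn(usize) {
    f
}

fn main() {
    let f = |index: usize| println!("hello from {index}");

    let erased = erase_manually(&f);
    // SAFETY: `f` is still alive at this point.
    unsafe { (erased.call)(erased.data, 0) };

    erase_with_dyn(&f)(1);
}
```

With `&dyn Fn`, the function pointer, the data pointer and their pairing are maintained by the compiler, which is what removes the manual trampoline management.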
Also note that `thread::scope` currently allocates a single `Arc` internally and infallibly, but that appears to be an implementation detail of the standard library. But then again, so does `thread::Builder::spawn_unchecked`, cf. https://doc.rust-lang.org/stable/src/std/thread/mod.rs.html#512, which makes the whole allocation-error-checking discipline a bit moot at the moment.