Consider using a scope-based API #10

@adamreichold

I think that the implementation, at least, could be simplified significantly by using an API based on the standard library's thread::scope mechanism.
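For reference, the standard library mechanism meant here is std::thread::scope (stable since Rust 1.63): it joins every spawned thread before returning, which is what lets those threads borrow non-'static data. A minimal illustration of just that mechanism, not part of the proposal itself:

use std::thread;

fn main() {
    let data = vec![1, 2, 3];

    // `thread::scope` joins both threads before it returns, so they may
    // borrow `data` even though it does not live for 'static.
    thread::scope(|s| {
        s.spawn(|| println!("first element: {:?}", &data[..1]));
        s.spawn(|| println!("remaining elements: {:?}", &data[1..]));
    });
}

The proposed implementation building on this could then look roughly as follows: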

use std::cell::Cell;
use std::mem::transmute;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

pub struct Scope<'scope> {
    state: &'scope State,
}

impl Scope<'_> {
    pub fn broadcast<F>(&self, f: F)
    where
        F: Fn(usize) + Sync,
    {
        let state = self.state;

        // Bind the work item to a local so that it stays alive until `_guard`
        // has reset `state.work`; a borrow of a temporary would dangle once
        // the statement below ends.
        let work = |index: usize, _stop: &mut bool| f(index);

        // SAFETY: `_guard` will reset `state.work` before `work` is dropped,
        // but only after all pending workers are finished.
        unsafe {
            state
                .work
                .set(transmute::<&Work, &'static Work>(&work));
        }

        state.pending.store(self.state.workers, Ordering::Relaxed);
        state.generation.fetch_add(1, Ordering::Release);

        // Wait for all workers to finish the current work item before it can
        // go out of scope.
        struct ResetGuard<'scope>(&'scope State);

        impl Drop for ResetGuard<'_> {
            fn drop(&mut self) {
                let state = self.0;

                while state.pending.load(Ordering::Acquire) != 0 {
                    thread::yield_now();
                }

                state.work.set(STOP);
            }
        }

        let _guard = ResetGuard(state);

        // The calling thread takes part in the broadcast itself as index zero.
        f(0);
    }
}

pub fn scope<F, R>(parallelism: usize, f: F) -> R
where
    F: for<'scope> FnOnce(Scope<'scope>) -> R,
{
    assert!(parallelism > 1);

    let state = &State {
        work: Cell::new(STOP),
        workers: parallelism - 1,
        pending: Padded(AtomicUsize::new(0)),
        generation: Padded(AtomicUsize::new(0)),
    };

    thread::scope(|scope| {
        for index in 1..parallelism {
            scope.spawn(move || state.worker(index));
        }

        // Broadcast the STOP work item when the scope ends, even if `f` panics.
        struct StopGuard<'scope>(&'scope State);

        impl Drop for StopGuard<'_> {
            fn drop(&mut self) {
                let state = self.0;

                state.work.set(STOP);

                state.generation.fetch_add(1, Ordering::Release);
            }
        }

        let _guard = StopGuard(state);

        f(Scope { state })
    })
}

#[repr(align(128))]
struct State {
    work: Cell<&'static Work<'static>>,
    workers: usize,
    pending: Padded<AtomicUsize>,
    generation: Padded<AtomicUsize>,
}

unsafe impl Send for State {}

unsafe impl Sync for State {}

impl State {
    fn worker(&self, index: usize) {
        let mut last_generation = 0;

        loop {
            // Spin until a new generation, i.e. a new work item, is published.
            loop {
                let curr_generation = self.generation.load(Ordering::Acquire);

                if last_generation != curr_generation {
                    last_generation = curr_generation;
                    break;
                } else {
                    thread::yield_now();
                }
            }

            let mut stop = false;

            // Run the current work item; the STOP item asks this worker to exit.
            self.work.get()(index, &mut stop);

            if stop {
                return;
            }

            // Signal completion so that `broadcast` can return.
            self.pending.fetch_sub(1, Ordering::Release);
        }
    }
}

type Work<'work> = dyn Fn(usize, &mut bool) + 'work;

const STOP: &Work = &|_index, stop: &mut bool| *stop = true;

#[repr(align(128))]
struct Padded<T>(T);

impl<T> Deref for Padded<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::atomic::AtomicBool;

    #[test]
    fn it_works() {
        let parallelism = 4;

        let flags = (0..parallelism)
            .map(|_| AtomicBool::new(false))
            .collect::<Vec<_>>();

        scope(parallelism, |scope| {
            scope.broadcast(|index| {
                flags[index].store(true, Ordering::Relaxed);
            });
        });

        for flag in flags {
            assert!(flag.load(Ordering::Relaxed));
        }
    }
}

There is some other stuff in there as well, e.g. avoiding the manual trampoline management by relying on compiler-generated vtables, avoiding the second atomic load for the stop flag by broadcasting a STOP work item, and using RAII guards for unwind safety.
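To make the last point concrete: if the closure passed to scope panics, StopGuard::drop still runs during unwinding and broadcasts the STOP work item, so thread::scope can join the workers before the panic propagates. A hypothetical test along these lines (the test name is mine) could be added to the tests module above:

#[test]
fn panics_do_not_leak_workers() {
    use std::panic::catch_unwind;

    let result = catch_unwind(|| {
        scope(4, |scope| {
            scope.broadcast(|_index| {});

            panic!("user code failed");
        })
    });

    // `StopGuard::drop` ran during unwinding, the workers received the STOP
    // work item and exited, and `thread::scope` joined them before the panic
    // was rethrown here.
    assert!(result.is_err());
}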

Also note that thread::scope currently allocates a single Arc internally and infallibly, but that appears to be an implementation detail of the standard library. But then again, so does thread::Builder::spawn_unchecked (cf. https://doc.rust-lang.org/stable/src/std/thread/mod.rs.html#512), which makes the whole allocation error checking discipline a bit moot at the moment.
