Skip to content

Commit ac41200

Browse files
committed
Merge branch 'pr/feat-worker-yield' into main-dev
2 parents 1e9540f + 541842c commit ac41200

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

include/fork_union.hpp

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,19 @@ class unique_padded_buffer {
588588
*/
589589
struct dummy_lambda_t {};
590590

591+
/**
 *  @brief Compile-time detection of which call signatures a yield functor supports.
 *
 *  A micro-yield may be invocable with no arguments (`yield()`) and/or with a
 *  thread index (`yield(thread_index)`). Both forms must be `noexcept` and
 *  return `void` to count as supported.
 *
 *  @tparam yield_type_ The functor type used for busy-waiting.
 *  @tparam thread_index_type_ The integral type used to identify the calling thread.
 */
template <typename yield_type_, typename thread_index_type_>
struct yield_traits {
    /// True if `yield()` is nothrow-invocable and returns `void`.
    static constexpr bool supports_no_arg = std::is_nothrow_invocable_r_v<void, yield_type_>;
    /// True if `yield(thread_index)` is nothrow-invocable and returns `void`.
    static constexpr bool supports_thread_index = std::is_nothrow_invocable_r_v<void, yield_type_, thread_index_type_>;
    /// A usable yield functor must support at least one of the two forms.
    static constexpr bool valid = supports_no_arg || supports_thread_index;
};

/**
 *  @brief Invokes @p yield, forwarding @p thread_index only if the functor accepts it.
 *
 *  Prefers the index-taking overload when available, so index-aware yields
 *  (e.g. per-thread back-off strategies) receive the caller's identity;
 *  otherwise falls back to the zero-argument form.
 *
 *  The thread index is defaulted so call sites that have no meaningful index
 *  (e.g. `spin_mutex::lock`) can write `call_yield_(yield)`: without the
 *  defaults, the one-argument invocations in `spin_mutex` are ill-formed,
 *  since `thread_index_type_` cannot be deduced.
 *
 *  @param yield The yield functor to invoke; must satisfy `yield_traits<...>::valid`.
 *  @param thread_index Index of the calling thread; `0` when the caller has none.
 */
template <typename yield_type_, typename thread_index_type_ = std::size_t>
inline void call_yield_(yield_type_ &yield, thread_index_type_ thread_index = 0) noexcept {
    if constexpr (yield_traits<yield_type_, thread_index_type_>::supports_thread_index) { yield(thread_index); }
    else { yield(); }
}
603+
591604
/**
592605
* @brief A trivial minimalistic lock-free "mutex" implementation using `std::atomic_flag`.
593606
* @tparam micro_yield_type_ The type of the yield function to be used for busy-waiting.
@@ -610,7 +623,7 @@ class spin_mutex {
610623
public:
611624
void lock() noexcept {
612625
micro_yield_t micro_yield;
613-
while (flag_.test_and_set(std::memory_order_acquire)) micro_yield();
626+
while (flag_.test_and_set(std::memory_order_acquire)) call_yield_(micro_yield);
614627
}
615628
bool try_lock() noexcept { return !flag_.test_and_set(std::memory_order_acquire); }
616629
void unlock() noexcept { flag_.clear(std::memory_order_release); }
@@ -634,7 +647,7 @@ class spin_mutex {
634647
public:
635648
void lock() noexcept {
636649
micro_yield_t micro_yield;
637-
while (flag_.exchange(true, std::memory_order_acquire)) micro_yield();
650+
while (flag_.exchange(true, std::memory_order_acquire)) call_yield_(micro_yield);
638651
}
639652
bool try_lock() noexcept { return !flag_.exchange(true, std::memory_order_acquire); }
640653
void unlock() noexcept { flag_.store(false, std::memory_order_release); }
@@ -1022,8 +1035,6 @@ class basic_pool {
10221035
public:
10231036
using allocator_t = allocator_type_;
10241037
using micro_yield_t = micro_yield_type_;
1025-
static_assert(std::is_nothrow_invocable_r<void, micro_yield_t>::value,
1026-
"Yield must be callable w/out arguments & return void");
10271038
static constexpr std::size_t alignment_k = alignment_;
10281039
static_assert(is_power_of_two(alignment_k), "Alignment must be a power of 2");
10291040

@@ -1039,6 +1050,9 @@ class basic_pool {
10391050
using punned_fork_context_t = void *; // ? Pointer to the on-stack lambda
10401051
using trampoline_t = void (*)(punned_fork_context_t, thread_index_t); // ? Wraps lambda's `operator()`
10411052

1053+
using micro_yield_traits_t = yield_traits<micro_yield_t, thread_index_t>;
1054+
static_assert(micro_yield_traits_t::valid, "Yield must be invocable w/out args or with a thread index");
1055+
10421056
private:
10431057
// Thread-pool-specific variables:
10441058
allocator_t allocator_ {};
@@ -1234,7 +1248,8 @@ class basic_pool {
12341248

12351249
// Actually wait for everyone to finish
12361250
micro_yield_t micro_yield;
1237-
while (threads_to_sync_.load(std::memory_order_acquire)) micro_yield();
1251+
while (threads_to_sync_.load(std::memory_order_acquire))
1252+
call_yield_(micro_yield, static_cast<thread_index_t>(0));
12381253
}
12391254

12401255
#pragma endregion Core API
@@ -1420,7 +1435,7 @@ class basic_pool {
14201435
micro_yield_t micro_yield;
14211436
while ((new_epoch = epoch_.load(std::memory_order_acquire)) == last_epoch &&
14221437
(mood = mood_.load(std::memory_order_acquire)) == mood_t::grind_k)
1423-
micro_yield();
1438+
call_yield_(micro_yield, thread_index);
14241439

14251440
if (fu_unlikely_(mood == mood_t::die_k)) break;
14261441
if (fu_unlikely_(mood == mood_t::chill_k) && (new_epoch == last_epoch)) {
@@ -2534,8 +2549,6 @@ struct linux_colocated_pool {
25342549
public:
25352550
using allocator_t = linux_numa_allocator_t;
25362551
using micro_yield_t = micro_yield_type_;
2537-
static_assert(std::is_nothrow_invocable_r<void, micro_yield_t>::value,
2538-
"Yield must be callable w/out arguments & return void");
25392552
static constexpr std::size_t alignment_k = alignment_;
25402553
static_assert(alignment_k > 0 && (alignment_k & (alignment_k - 1)) == 0, "Alignment must be a power of 2");
25412554

@@ -2549,6 +2562,9 @@ struct linux_colocated_pool {
25492562
using punned_fork_context_t = void *; // ? Pointer to the on-stack lambda
25502563
using trampoline_t = void (*)(punned_fork_context_t, colocated_thread_t); // ? Wraps lambda's `operator()`
25512564

2565+
using micro_yield_traits_t = yield_traits<micro_yield_t, thread_index_t>;
2566+
static_assert(micro_yield_traits_t::valid, "Yield must be invocable w/out args or with a thread index");
2567+
25522568
private:
25532569
using allocator_traits_t = std::allocator_traits<allocator_t>;
25542570
using numa_pthread_allocator_t = typename allocator_traits_t::template rebind_alloc<numa_pthread_t>;
@@ -2878,7 +2894,8 @@ struct linux_colocated_pool {
28782894

28792895
// Actually wait for everyone to finish
28802896
micro_yield_t micro_yield;
2881-
while (threads_to_sync_.load(std::memory_order_acquire)) micro_yield();
2897+
while (threads_to_sync_.load(std::memory_order_acquire))
2898+
call_yield_(micro_yield, static_cast<thread_index_t>(0));
28822899
}
28832900

28842901
#pragma endregion Core API
@@ -3088,7 +3105,9 @@ struct linux_colocated_pool {
30883105
// so spin-loop for a bit until the pool is ready.
30893106
mood_t mood;
30903107
micro_yield_t micro_yield;
3091-
while ((mood = pool->mood_.load(std::memory_order_acquire)) == mood_t::chill_k) micro_yield();
3108+
while ((mood = pool->mood_.load(std::memory_order_acquire)) == mood_t::chill_k)
3109+
// Technically, we are not on the zero thread index, but we don't know our index yet.
3110+
call_yield_(micro_yield, static_cast<thread_index_t>(0));
30923111

30933112
// If we are ready to start grinding, export this thread's metadata to make it externally
30943113
// observable and controllable.
@@ -3123,7 +3142,7 @@ struct linux_colocated_pool {
31233142
// Wait for either: a new ticket or a stop flag
31243143
while ((new_epoch = pool->epoch_.load(std::memory_order_acquire)) == last_epoch &&
31253144
(mood = pool->mood_.load(std::memory_order_acquire)) == mood_t::grind_k)
3126-
micro_yield();
3145+
call_yield_(micro_yield, global_thread_index);
31273146

31283147
if (fu_unlikely_(mood == mood_t::die_k)) break;
31293148
if (fu_unlikely_(mood == mood_t::chill_k) && (new_epoch == last_epoch)) {

0 commit comments

Comments
 (0)