Skip to content

Commit 169e9e8

Browse files
gbaraldivchuravy
andauthored
Implement faster thread local rng for scheduler (#55501)
Implement optimal uniform random number generator using the method proposed in swiftlang/swift#39143 based on OpenSSL's implementation of it in https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99 This PR also fixes some bugs found while developing it. This is a replacement for #50203 and fixes the issues found by @IanButterworth with both rngs C rng <img width="1011" alt="image" src="https://github.com/user-attachments/assets/0dd9d5f2-17ef-4a70-b275-1d12692be060"> New scheduler rng <img width="985" alt="image" src="https://github.com/user-attachments/assets/4abd0a57-a1d9-46ec-99a5-535f366ecafa"> ~On my benchmarks the julia implementation seems to be almost 50% faster than the current implementation.~ With oscars suggestion of removing the debiasing this is now almost 5x faster than the original implementation. And almost fully branchless We might want to backport the two previous commits since they technically fix bugs. --------- Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>
1 parent 5272dad commit 169e9e8

File tree

6 files changed

+102
-10
lines changed

6 files changed

+102
-10
lines changed

base/partr.jl

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,60 @@ const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
2020
const heaps_lock = [SpinLock(), SpinLock()]
2121

2222

23-
cong(max::UInt32) = iszero(max) ? UInt32(0) : ccall(:jl_rand_ptls, UInt32, (UInt32,), max) + UInt32(1)
23+
"""
24+
cong(max::UInt32)
25+
26+
Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
27+
"""
28+
cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check
29+
30+
get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ())
31+
32+
set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)
33+
34+
"""
35+
rand_ptls(max::UInt32)
36+
37+
Return a random UInt32 in the range `0:max-1` using the thread-local RNG
38+
state. Max must be greater than 0.
39+
"""
40+
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
41+
rngseed = get_ptls_rng()
42+
val, seed = rand_uniform_max_int32(max, rngseed)
43+
set_ptls_rng(seed)
44+
return val % UInt32
45+
end
46+
47+
# This implementation is based on OpenSSLs implementation of rand_uniform
48+
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
49+
# Comments are vendored from their implementation as well.
50+
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.
51+
52+
# Essentially it boils down to incrementally generating a fixed point
53+
# number on the interval [0, 1) and multiplying this number by the upper
54+
# range limit. Once it is certain what the fractional part contributes to
55+
# the integral part of the product, the algorithm has produced a definitive
56+
# result.
57+
"""
58+
rand_uniform_max_int32(max::UInt32, seed::UInt64)
59+
60+
Return a random UInt32 in the range `0:max-1` using the given seed.
61+
Max must be greater than 0.
62+
"""
63+
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
64+
if max == UInt32(1)
65+
return UInt32(0), seed
66+
end
67+
# We are generating a fixed point number on the interval [0, 1).
68+
# Multiplying this by the range gives us a number on [0, upper).
69+
# The high word of the multiplication result represents the integral part
70+
# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes
71+
seed = UInt64(69069) * seed + UInt64(362437)
72+
prod = (UInt64(max)) * (seed % UInt32) # 64 bit product
73+
i = prod >> 32 % UInt32 # integral part
74+
return i % UInt32, seed
75+
end
76+
2477

2578

2679
function multiq_sift_up(heap::taskheap, idx::Int32)

src/ccall.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ TRANSFORMED_CCALL_STAT(jl_cpu_wake);
2222
TRANSFORMED_CCALL_STAT(jl_gc_safepoint);
2323
TRANSFORMED_CCALL_STAT(jl_get_ptls_states);
2424
TRANSFORMED_CCALL_STAT(jl_threadid);
25+
TRANSFORMED_CCALL_STAT(jl_get_ptls_rng);
26+
TRANSFORMED_CCALL_STAT(jl_set_ptls_rng);
2527
TRANSFORMED_CCALL_STAT(jl_get_tls_world_age);
2628
TRANSFORMED_CCALL_STAT(jl_get_world_counter);
2729
TRANSFORMED_CCALL_STAT(jl_gc_enable_disable_finalizers_internal);
@@ -1692,6 +1694,36 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
16921694
ai.decorateInst(tid);
16931695
return mark_or_box_ccall_result(ctx, tid, retboxed, rt, unionall, static_rt);
16941696
}
1697+
else if (is_libjulia_func(jl_get_ptls_rng)) {
1698+
++CCALL_STAT(jl_get_ptls_rng);
1699+
assert(lrt == getInt64Ty(ctx.builder.getContext()));
1700+
assert(!isVa && !llvmcall && nccallargs == 0);
1701+
JL_GC_POP();
1702+
Value *ptls_p = get_current_ptls(ctx);
1703+
const int rng_offset = offsetof(jl_tls_states_t, rngseed);
1704+
Value *rng_ptr = ctx.builder.CreateInBoundsGEP(getInt8Ty(ctx.builder.getContext()), ptls_p, ConstantInt::get(ctx.types().T_size, rng_offset / sizeof(int8_t)));
1705+
setName(ctx.emission_context, rng_ptr, "rngseed_ptr");
1706+
LoadInst *rng_value = ctx.builder.CreateAlignedLoad(getInt64Ty(ctx.builder.getContext()), rng_ptr, Align(sizeof(void*)));
1707+
setName(ctx.emission_context, rng_value, "rngseed");
1708+
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe);
1709+
ai.decorateInst(rng_value);
1710+
return mark_or_box_ccall_result(ctx, rng_value, retboxed, rt, unionall, static_rt);
1711+
}
1712+
else if (is_libjulia_func(jl_set_ptls_rng)) {
1713+
++CCALL_STAT(jl_set_ptls_rng);
1714+
assert(lrt == getVoidTy(ctx.builder.getContext()));
1715+
assert(!isVa && !llvmcall && nccallargs == 1);
1716+
JL_GC_POP();
1717+
Value *ptls_p = get_current_ptls(ctx);
1718+
const int rng_offset = offsetof(jl_tls_states_t, rngseed);
1719+
Value *rng_ptr = ctx.builder.CreateInBoundsGEP(getInt8Ty(ctx.builder.getContext()), ptls_p, ConstantInt::get(ctx.types().T_size, rng_offset / sizeof(int8_t)));
1720+
setName(ctx.emission_context, rng_ptr, "rngseed_ptr");
1721+
assert(argv[0].V->getType() == getInt64Ty(ctx.builder.getContext()));
1722+
auto store = ctx.builder.CreateAlignedStore(argv[0].V, rng_ptr, Align(sizeof(void*)));
1723+
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe);
1724+
ai.decorateInst(store);
1725+
return ghostValue(ctx, jl_nothing_type);
1726+
}
16951727
else if (is_libjulia_func(jl_get_tls_world_age)) {
16961728
bool toplevel = !(ctx.linfo && jl_is_method(ctx.linfo->def.method));
16971729
if (!toplevel) { // top level code does not see a stable world age during execution

src/jl_exported_funcs.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,8 @@
452452
XX(jl_test_cpu_feature) \
453453
XX(jl_threadid) \
454454
XX(jl_threadpoolid) \
455+
XX(jl_get_ptls_rng) \
456+
XX(jl_set_ptls_rng) \
455457
XX(jl_throw) \
456458
XX(jl_throw_out_of_memory_error) \
457459
XX(jl_too_few_args) \

src/julia_threads.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ extern "C" {
1818

1919
JL_DLLEXPORT int16_t jl_threadid(void);
2020
JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT;
21+
JL_DLLEXPORT uint64_t jl_get_ptls_rng(void) JL_NOTSAFEPOINT;
22+
JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT;
2123

2224
// JULIA_ENABLE_THREADING may be controlled by altering JULIA_THREADS in Make.user
2325

src/scheduler.c

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,6 @@ JL_DLLEXPORT int jl_set_task_threadpoolid(jl_task_t *task, int8_t tpid) JL_NOTSA
8484
extern int jl_gc_mark_queue_obj_explicit(jl_gc_mark_cache_t *gc_cache,
8585
jl_gc_markqueue_t *mq, jl_value_t *obj) JL_NOTSAFEPOINT;
8686

87-
// parallel task runtime
88-
// ---
89-
90-
JL_DLLEXPORT uint32_t jl_rand_ptls(uint32_t max) // [0, n)
91-
{
92-
jl_ptls_t ptls = jl_current_task->ptls;
93-
return cong(max, &ptls->rngseed);
94-
}
95-
9687
// initialize the threading infrastructure
9788
// (called only by the main thread)
9889
void jl_init_threadinginfra(void)

src/threading.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,18 @@ JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT
314314
return -1; // everything else uses threadpool -1 (does not belong to any threadpool)
315315
}
316316

317+
// get thread local rng
318+
JL_DLLEXPORT uint64_t jl_get_ptls_rng(void) JL_NOTSAFEPOINT
319+
{
320+
return jl_current_task->ptls->rngseed;
321+
}
322+
323+
// get thread local rng
324+
JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT
325+
{
326+
jl_current_task->ptls->rngseed = new_seed;
327+
}
328+
317329
jl_ptls_t jl_init_threadtls(int16_t tid)
318330
{
319331
#ifndef _OS_WINDOWS_

0 commit comments

Comments
 (0)