Skip to content

Commit 8f675fe

Browse files
committed
Host-side RNG interface.
1 parent 5b50809 commit 8f675fe

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

src/random.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,92 @@ using Random
44

55
export rand_logn!, rand_poisson!
66

7+
8+
# native RNG
9+
10+
struct RNG <: AbstractRNG
11+
state::CuVector{UInt32}
12+
13+
function RNG(seed)
14+
@assert length(seed) == 32
15+
new(seed)
16+
end
17+
end
18+
19+
RNG() = RNG(rand(UInt32, 32))
20+
21+
function Random.seed!(rng::RNG, seed::AbstractVector{UInt32})
22+
@assert length(seed) == 32
23+
copyto!(rng.state, seed)
24+
end
25+
26+
function Random.seed!(rng::RNG)
27+
Random.rand!(rng.state)
28+
return
29+
end
30+
31+
function Random.rand!(rng::RNG, A::AnyCuArray)
32+
function kernel(a::AbstractArray{T}, state::AbstractVector{UInt32}) where {T}
33+
device_rng = Random.default_rng()
34+
35+
# initialize the state
36+
tid = threadIdx().x
37+
if tid <= 32
38+
# we know for sure the seed will contain 32 values
39+
@inbounds Random.seed!(device_rng, state)
40+
end
41+
42+
sync_threads()
43+
44+
# grid-stride loop
45+
offset = (blockIdx().x - 1) * blockDim().x
46+
while offset < length(A)
47+
# generating random numbers synchronizes threads, so needs to happen uniformly.
48+
val = Random.rand(device_rng, T)
49+
50+
i = tid + offset
51+
if i <= length(a)
52+
@inbounds a[i] = val
53+
end
54+
55+
offset += (blockDim().x - 1) * gridDim().x
56+
end
57+
58+
sync_threads()
59+
60+
# save the device rng state of the first block (other blocks are derived from it)
61+
# so that subsequent launches generate different random numbers
62+
if blockIdx().x == 1 && tid <= 32
63+
@inbounds state[tid] = device_rng.state[tid]
64+
end
65+
66+
return
67+
end
68+
69+
kernel = @cuda launch=false name="rand!" kernel(A, rng.state)
70+
config = launch_configuration(kernel.fun; max_threads=64)
71+
threads = max(32, min(config.threads, length(A)))
72+
blocks = min(config.blocks, cld(length(A), threads))
73+
kernel(A, rng.state; threads=threads, blocks=blocks)
74+
75+
# XXX: updating the state from within the kernel is racey,
76+
# so we should have a RNG per stream
77+
78+
A
79+
end
80+
81+
function Random.rand(rng::RNG, T::Type)
82+
assertscalar("scalar rand")
83+
A = CuArray{T}(undef, 1)
84+
Random.rand!(rng, A)
85+
A[]
86+
end
87+
88+
# TODO: `randn!`; cannot reuse from Base or RandomNumbers, as those do scalar indexing
89+
90+
91+
# RNG-less interface
92+
793
# the interface is split in two levels:
894
# - functions that extend the Random standard library, and take an RNG as first argument,
995
# will only ever dispatch to CURAND and as a result are limited in the types they support.

0 commit comments

Comments
 (0)