@@ -4,6 +4,92 @@ using Random
4
4
5
5
export rand_logn!, rand_poisson!
6
6
7
+
8
+ # native RNG
9
+
10
+ struct RNG <: AbstractRNG
11
+ state:: CuVector{UInt32}
12
+
13
+ function RNG (seed)
14
+ @assert length (seed) == 32
15
+ new (seed)
16
+ end
17
+ end
18
+
19
+ RNG () = RNG (rand (UInt32, 32 ))
20
+
21
+ function Random. seed! (rng:: RNG , seed:: AbstractVector{UInt32} )
22
+ @assert length (seed) == 32
23
+ copyto! (rng. state, seed)
24
+ end
25
+
26
+ function Random. seed! (rng:: RNG )
27
+ Random. rand! (rng. state)
28
+ return
29
+ end
30
+
31
+ function Random. rand! (rng:: RNG , A:: AnyCuArray )
32
+ function kernel (a:: AbstractArray{T} , state:: AbstractVector{UInt32} ) where {T}
33
+ device_rng = Random. default_rng ()
34
+
35
+ # initialize the state
36
+ tid = threadIdx (). x
37
+ if tid <= 32
38
+ # we know for sure the seed will contain 32 values
39
+ @inbounds Random. seed! (device_rng, state)
40
+ end
41
+
42
+ sync_threads ()
43
+
44
+ # grid-stride loop
45
+ offset = (blockIdx (). x - 1 ) * blockDim (). x
46
+ while offset < length (A)
47
+ # generating random numbers synchronizes threads, so needs to happen uniformly.
48
+ val = Random. rand (device_rng, T)
49
+
50
+ i = tid + offset
51
+ if i <= length (a)
52
+ @inbounds a[i] = val
53
+ end
54
+
55
+ offset += (blockDim (). x - 1 ) * gridDim (). x
56
+ end
57
+
58
+ sync_threads ()
59
+
60
+ # save the device rng state of the first block (other blocks are derived from it)
61
+ # so that subsequent launches generate different random numbers
62
+ if blockIdx (). x == 1 && tid <= 32
63
+ @inbounds state[tid] = device_rng. state[tid]
64
+ end
65
+
66
+ return
67
+ end
68
+
69
+ kernel = @cuda launch= false name= " rand!" kernel (A, rng. state)
70
+ config = launch_configuration (kernel. fun; max_threads= 64 )
71
+ threads = max (32 , min (config. threads, length (A)))
72
+ blocks = min (config. blocks, cld (length (A), threads))
73
+ kernel (A, rng. state; threads= threads, blocks= blocks)
74
+
75
+ # XXX : updating the state from within the kernel is racey,
76
+ # so we should have a RNG per stream
77
+
78
+ A
79
+ end
80
+
81
+ function Random. rand (rng:: RNG , T:: Type )
82
+ assertscalar (" scalar rand" )
83
+ A = CuArray {T} (undef, 1 )
84
+ Random. rand! (rng, A)
85
+ A[]
86
+ end
87
+
88
+ # TODO : `randn!`; cannot reuse from Base or RandomNumbers, as those do scalar indexing
89
+
90
+
91
+ # RNG-less interface
92
+
7
93
# the interface is split in two levels:
8
94
# - functions that extend the Random standard library, and take an RNG as first argument,
9
95
# will only ever dispatch to CURAND and as a result are limited in the types they support.
0 commit comments