Skip to content

Commit 7708eb1

Browse files
authored
Merge pull request #25004 from JuliaLang/rf/rand/UInt52
RNG: delete rand_ui* functions in favor of multi-dispatch
2 parents 2da4c18 + 3c65a84 commit 7708eb1

File tree

5 files changed

+98
-48
lines changed

5 files changed

+98
-48
lines changed

base/random/RNGs.jl

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ for T in (Bool, BitInteger_types...)
4242
end
4343
end
4444

45+
# RandomDevice produces natively UInt64
46+
rng_native_52(::RandomDevice) = UInt64
47+
4548
"""
4649
RandomDevice()
4750
@@ -54,10 +57,6 @@ RandomDevice
5457
RandomDevice(::Void) = RandomDevice()
5558
srand(rng::RandomDevice) = rng
5659

57-
### generation of floats
58-
59-
rand(r::RandomDevice, sp::SamplerTrivial{<:FloatInterval}) = rand_generic(r, sp[])
60-
6160

6261
## MersenneTwister
6362

@@ -209,58 +208,59 @@ const GLOBAL_RNG = MersenneTwister(0)
209208

210209
### generation
211210

211+
# MersenneTwister produces natively Float64
212+
rng_native_52(::MersenneTwister) = Float64
213+
212214
#### helper functions
213215

214216
# precondition: !mt_empty(r)
215217
rand_inbounds(r::MersenneTwister, ::Close1Open2_64) = mt_pop!(r)
216-
rand_inbounds(r::MersenneTwister, ::CloseOpen_64) =
218+
rand_inbounds(r::MersenneTwister, ::CloseOpen_64=CloseOpen()) =
217219
rand_inbounds(r, Close1Open2()) - 1.0
218-
rand_inbounds(r::MersenneTwister) = rand_inbounds(r, CloseOpen())
219220

220-
rand_ui52_raw_inbounds(r::MersenneTwister) =
221-
reinterpret(UInt64, rand_inbounds(r, Close1Open2()))
222-
rand_ui52_raw(r::MersenneTwister) = (reserve_1(r); rand_ui52_raw_inbounds(r))
221+
rand_inbounds(r::MersenneTwister, ::UInt52Raw{T}) where {T<:BitInteger} =
222+
reinterpret(UInt64, rand_inbounds(r, Close1Open2())) % T
223223

224-
function rand_ui2x52_raw(r::MersenneTwister)
225-
reserve(r, 2)
226-
rand_ui52_raw_inbounds(r) % UInt128 << 64 | rand_ui52_raw_inbounds(r)
224+
function rand(r::MersenneTwister, x::SamplerTrivial{UInt52Raw{UInt64}})
225+
reserve_1(r)
226+
rand_inbounds(r, x[])
227227
end
228228

229-
function rand_ui104_raw(r::MersenneTwister)
229+
function rand(r::MersenneTwister, ::SamplerTrivial{UInt2x52Raw{UInt128}})
230230
reserve(r, 2)
231-
rand_ui52_raw_inbounds(r) % UInt128 << 52 rand_ui52_raw_inbounds(r)
231+
rand_inbounds(r, UInt52Raw(UInt128)) << 64 | rand_inbounds(r, UInt52Raw(UInt128))
232232
end
233233

234-
rand_ui10_raw(r::MersenneTwister) = rand_ui52_raw(r)
235-
rand_ui23_raw(r::MersenneTwister) = rand_ui52_raw(r)
234+
function rand(r::MersenneTwister, ::SamplerTrivial{UInt104Raw{UInt128}})
235+
reserve(r, 2)
236+
rand_inbounds(r, UInt52Raw(UInt128)) << 52 rand_inbounds(r, UInt52Raw(UInt128))
237+
end
236238

237239
#### floats
238240

239-
rand(r::MersenneTwister, sp::SamplerTrivial{<:FloatInterval_64}) =
241+
rand(r::MersenneTwister, sp::SamplerTrivial{Close1Open2_64}) =
240242
(reserve_1(r); rand_inbounds(r, sp[]))
241243

242-
rand(r::MersenneTwister, sp::SamplerTrivial{<:FloatInterval}) = rand_generic(r, sp[])
243-
244244
#### integers
245245

246246
rand(r::MersenneTwister,
247247
T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
248-
rand_ui52_raw(r) % T[]
248+
rand(r, UInt52Raw()) % T[]
249249

250250
function rand(r::MersenneTwister, ::SamplerType{UInt64})
251251
reserve(r, 2)
252-
rand_ui52_raw_inbounds(r) << 32 rand_ui52_raw_inbounds(r)
252+
rand_inbounds(r, UInt52Raw()) << 32 rand_inbounds(r, UInt52Raw())
253253
end
254254

255255
function rand(r::MersenneTwister, ::SamplerType{UInt128})
256256
reserve(r, 3)
257-
xor(rand_ui52_raw_inbounds(r) % UInt128 << 96,
258-
rand_ui52_raw_inbounds(r) % UInt128 << 48,
259-
rand_ui52_raw_inbounds(r))
257+
xor(rand_inbounds(r, UInt52Raw(UInt128)) << 96,
258+
rand_inbounds(r, UInt52Raw(UInt128)) << 48,
259+
rand_inbounds(r, UInt52Raw(UInt128)))
260260
end
261261

262-
rand(r::MersenneTwister, ::SamplerType{Int64}) = reinterpret(Int64, rand(r, UInt64))
263-
rand(r::MersenneTwister, ::SamplerType{Int128}) = reinterpret(Int128, rand(r, UInt128))
262+
rand(r::MersenneTwister, ::SamplerType{Int64}) = rand(r, UInt64) % Int64
263+
rand(r::MersenneTwister, ::SamplerType{Int128}) = rand(r, UInt128) % Int128
264264

265265
#### arrays of floats
266266

@@ -395,7 +395,7 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, ::SamplerType{UInt128})
395395
end
396396
end
397397
if n > 0
398-
u = rand_ui2x52_raw(r)
398+
u = rand(r, UInt2x52Raw())
399399
for i = 1:n
400400
@inbounds A[i] ⊻= u << (12*i)
401401
end

base/random/generation.jl

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,17 @@ Sampler(rng::AbstractRNG, ::Type{T}, n::Repetition) where {T<:AbstractFloat} =
2323
# generic random generation function which can be used by RNG implementors
2424
# it is not defined as a fallback rand method as this could create ambiguities
2525

26-
rand_generic(r::AbstractRNG, ::CloseOpen{Float16}) =
26+
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen{Float16}}) =
2727
Float16(reinterpret(Float32,
28-
(rand_ui10_raw(r) % UInt32 << 13) & 0x007fe000 | 0x3f800000) - 1)
28+
(rand(r, UInt10(UInt32)) << 13) | 0x3f800000) - 1)
2929

30-
rand_generic(r::AbstractRNG, ::CloseOpen{Float32}) =
31-
reinterpret(Float32, rand_ui23_raw(r) % UInt32 & 0x007fffff | 0x3f800000) - 1
30+
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen{Float32}}) =
31+
reinterpret(Float32, rand(r, UInt23()) | 0x3f800000) - 1
3232

33-
rand_generic(r::AbstractRNG, ::Close1Open2_64) =
34-
reinterpret(Float64, 0x3ff0000000000000 | rand(r, UInt64) & 0x000fffffffffffff)
33+
rand(r::AbstractRNG, ::SamplerTrivial{Close1Open2_64}) =
34+
reinterpret(Float64, 0x3ff0000000000000 | rand(r, UInt52()))
3535

36-
rand_generic(r::AbstractRNG, ::CloseOpen_64) = rand(r, Close1Open2()) - 1.0
36+
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen_64}) = rand(r, Close1Open2()) - 1.0
3737

3838
#### BigFloat
3939

@@ -101,11 +101,21 @@ rand(rng::AbstractRNG, sp::SamplerBigFloat{T}) where {T<:FloatInterval{BigFloat}
101101

102102
### random integers
103103

104-
rand_ui10_raw(r::AbstractRNG) = rand(r, UInt16)
105-
rand_ui23_raw(r::AbstractRNG) = rand(r, UInt32)
104+
rand(r::AbstractRNG, ::SamplerTrivial{UInt10Raw{UInt16}}) = rand(r, UInt16)
105+
rand(r::AbstractRNG, ::SamplerTrivial{UInt23Raw{UInt32}}) = rand(r, UInt32)
106106

107-
rand_ui52_raw(r::AbstractRNG) = reinterpret(UInt64, rand(r, Close1Open2()))
108-
rand_ui52(r::AbstractRNG) = rand_ui52_raw(r) & 0x000fffffffffffff
107+
rand(r::AbstractRNG, ::SamplerTrivial{UInt52Raw{UInt64}}) =
108+
_rand52(r, rng_native_52(r))
109+
110+
_rand52(r::AbstractRNG, ::Type{Float64}) = reinterpret(UInt64, rand(r, Close1Open2()))
111+
_rand52(r::AbstractRNG, ::Type{UInt64}) = rand(r, UInt64)
112+
113+
rand(r::AbstractRNG, ::SamplerTrivial{UInt10{UInt16}}) = rand(r, UInt10Raw()) & 0x03ff
114+
rand(r::AbstractRNG, ::SamplerTrivial{UInt23{UInt32}}) = rand(r, UInt23Raw()) & 0x007fffff
115+
rand(r::AbstractRNG, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt52Raw()) & 0x000fffffffffffff
116+
117+
rand(r::AbstractRNG, sp::SamplerTrivial{<:UniformBits{T}}) where {T} =
118+
rand(r, uint_default(sp[])) % T
109119

110120
### random complex numbers
111121

@@ -158,25 +168,28 @@ function SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Int128,UInt128
158168
SamplerRangeFast{UInt128,T}(first(r), bw, m, mask)
159169
end
160170

161-
function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer
171+
function rand_lteq(r::AbstractRNG, S, u::U, mask::U) where U<:Integer
162172
while true
163-
x = randfun(r) & mask
173+
x = rand(r, S) & mask
164174
x <= u && return x
165175
end
166176
end
167177

178+
# helper function, to turn types to values, should be removed once we can do rand(Uniform(UInt))
179+
rand(rng::AbstractRNG, ::Val{T}) where {T} = rand(rng, T)
180+
168181
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt64,T}) where T
169182
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
170-
x = bw <= 52 ? rand_lteq(rng, rand_ui52_raw, m, mask) :
171-
rand_lteq(rng, rng->rand(rng, UInt64), m, mask)
183+
x = bw <= 52 ? rand_lteq(rng, UInt52Raw(), m, mask) :
184+
rand_lteq(rng, Val(UInt64), m, mask)
172185
(x + a % UInt64) % T
173186
end
174187

175188
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt128,T}) where T
176189
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
177-
x = bw <= 52 ? rand_lteq(rng, rand_ui52_raw, m % UInt64, mask % UInt64) % UInt128 :
178-
bw <= 104 ? rand_lteq(rng, rand_ui104_raw, m, mask) :
179-
rand_lteq(rng, rng->rand(rng, UInt128), m, mask)
190+
x = bw <= 52 ? rand_lteq(rng, UInt52Raw(), m % UInt64, mask % UInt64) % UInt128 :
191+
bw <= 104 ? rand_lteq(rng, UInt104Raw(), m, mask) :
192+
rand_lteq(rng, Val(UInt128), m, mask)
180193
x % T + a
181194
end
182195

base/random/misc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ randsubseq(A::AbstractArray, p::Real) = randsubseq(GLOBAL_RNG, A, p)
147147
@inline function rand_lt(r::AbstractRNG, n::Int, mask::Int=nextpow2(n)-1)
148148
# this duplicates the functionality of rand(1:n), to optimize this special case
149149
while true
150-
x = (rand_ui52_raw(r) % Int) & mask
150+
x = rand(r, UInt52Raw(Int)) & mask
151151
x < n && return x
152152
end
153153
end

base/random/normal.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ julia> randn(rng, Complex64, (2, 3))
3535
"""
3636
@inline function randn(rng::AbstractRNG=GLOBAL_RNG)
3737
@inbounds begin
38-
r = rand_ui52(rng)
38+
r = rand(rng, UInt52())
3939
rabs = Int64(r>>1) # One bit for the sign
4040
idx = rabs & 0xFF
4141
x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1]
@@ -95,7 +95,7 @@ julia> randexp(rng, 3, 3)
9595
"""
9696
function randexp(rng::AbstractRNG=GLOBAL_RNG)
9797
@inbounds begin
98-
ri = rand_ui52(rng)
98+
ri = rand(rng, UInt52())
9999
idx = ri & 0xFF
100100
x = ri*we[idx+1]
101101
ri < ke[idx+1] && return x # 98.9% of the time we return here 1st try

base/random/random.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,43 @@ export srand,
2525

2626
abstract type AbstractRNG end
2727

28+
### integers
29+
30+
# we define types which encode the generation of a specific number of bits
31+
# the "raw" version means that the unused bits are not zeroed
32+
33+
abstract type UniformBits{T<:BitInteger} end
34+
35+
struct UInt10{T} <: UniformBits{T} end
36+
struct UInt10Raw{T} <: UniformBits{T} end
37+
38+
struct UInt23{T} <: UniformBits{T} end
39+
struct UInt23Raw{T} <: UniformBits{T} end
40+
41+
struct UInt52{T} <: UniformBits{T} end
42+
struct UInt52Raw{T} <: UniformBits{T} end
43+
44+
struct UInt104{T} <: UniformBits{T} end
45+
struct UInt104Raw{T} <: UniformBits{T} end
46+
47+
struct UInt2x52{T} <: UniformBits{T} end
48+
struct UInt2x52Raw{T} <: UniformBits{T} end
49+
50+
uint_sup(::Type{<:Union{UInt10,UInt10Raw}}) = UInt16
51+
uint_sup(::Type{<:Union{UInt23,UInt23Raw}}) = UInt32
52+
uint_sup(::Type{<:Union{UInt52,UInt52Raw}}) = UInt64
53+
uint_sup(::Type{<:Union{UInt104,UInt104Raw}}) = UInt128
54+
uint_sup(::Type{<:Union{UInt2x52,UInt2x52Raw}}) = UInt128
55+
56+
for UI = (:UInt10, :UInt10Raw, :UInt23, :UInt23Raw, :UInt52, :UInt52Raw,
57+
:UInt104, :UInt104Raw, :UInt2x52, :UInt2x52Raw)
58+
@eval begin
59+
$UI(::Type{T}=uint_sup($UI)) where {T} = $UI{T}()
60+
# useful for defining rand generically:
61+
uint_default(::$UI) = $UI{uint_sup($UI)}()
62+
end
63+
end
64+
2865
### floats
2966

3067
abstract type FloatInterval{T<:AbstractFloat} end

0 commit comments

Comments
 (0)