Skip to content

Commit 770b53d

Browse files
authored
Merge pull request #25047 from JuliaLang/rf/rand/range-52
faster rand!(::MersenneTwister, ::UnitRange)
2 parents 0649d0e + 5bb6b4d commit 770b53d

File tree

4 files changed

+127
-85
lines changed

4 files changed

+127
-85
lines changed

base/random/generation.jl

Lines changed: 77 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,13 @@ rand(r::AbstractRNG, ::SamplerTrivial{UInt52Raw{UInt64}}) =
110110
_rand52(r::AbstractRNG, ::Type{Float64}) = reinterpret(UInt64, rand(r, Close1Open2()))
111111
_rand52(r::AbstractRNG, ::Type{UInt64}) = rand(r, UInt64)
112112

113-
rand(r::AbstractRNG, ::SamplerTrivial{UInt10{UInt16}}) = rand(r, UInt10Raw()) & 0x03ff
114-
rand(r::AbstractRNG, ::SamplerTrivial{UInt23{UInt32}}) = rand(r, UInt23Raw()) & 0x007fffff
115-
rand(r::AbstractRNG, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt52Raw()) & 0x000fffffffffffff
113+
rand(r::AbstractRNG, ::SamplerTrivial{UInt104Raw{UInt128}}) =
114+
rand(r, UInt52Raw(UInt128)) << 52 rand_inbounds(r, UInt52Raw(UInt128))
115+
116+
rand(r::AbstractRNG, ::SamplerTrivial{UInt10{UInt16}}) = rand(r, UInt10Raw()) & 0x03ff
117+
rand(r::AbstractRNG, ::SamplerTrivial{UInt23{UInt32}}) = rand(r, UInt23Raw()) & 0x007fffff
118+
rand(r::AbstractRNG, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt52Raw()) & 0x000fffffffffffff
119+
rand(r::AbstractRNG, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw()) & 0x000000ffffffffffffffffffffffffff
116120

117121
rand(r::AbstractRNG, sp::SamplerTrivial{<:UniformBits{T}}) where {T} =
118122
rand(r, uint_default(sp[])) % T
@@ -143,6 +147,12 @@ end
143147
# 2) "Default" which tries to use as few entropy bits as possible, at the cost of a
144148
# a bigger upfront price associated with the creation of the sampler
145149

150+
#### helper functions
151+
152+
uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
153+
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
154+
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
155+
146156
#### Fast
147157

148158
struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
@@ -152,44 +162,37 @@ struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
152162
mask::U # mask generated values before threshold rejection
153163
end
154164

155-
function SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Base.BitInteger64,Bool}
156-
isempty(r) && throw(ArgumentError("range must be non-empty"))
157-
m = last(r) % UInt64 - first(r) % UInt64
158-
bw = (64 - leading_zeros(m)) % UInt # bit-width
159-
mask = (1 % UInt64 << bw) - (1 % UInt64)
160-
SamplerRangeFast{UInt64,T}(first(r), bw, m, mask)
161-
end
165+
SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
166+
SamplerRangeFast(r, uint_sup(T))
162167

163-
function SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Int128,UInt128}
168+
function SamplerRangeFast(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
164169
isempty(r) && throw(ArgumentError("range must be non-empty"))
165-
m = (last(r)-first(r)) % UInt128
166-
bw = (128 - leading_zeros(m)) % UInt # bit-width
167-
mask = (1 % UInt128 << bw) - (1 % UInt128)
168-
SamplerRangeFast{UInt128,T}(first(r), bw, m, mask)
170+
m = (last(r) - first(r)) % U
171+
bw = (sizeof(U) << 3 - leading_zeros(m)) % UInt # bit-width
172+
mask = (1 % U << bw) - (1 % U)
173+
SamplerRangeFast{U,T}(first(r), bw, m, mask)
169174
end
170175

171-
function rand_lteq(r::AbstractRNG, S, u::U, mask::U) where U<:Integer
172-
while true
173-
x = rand(r, S) & mask
174-
x <= u && return x
175-
end
176+
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt32,T}) where T
177+
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
178+
x = rand(rng, LessThan(m, Masked(mask, uniform(UInt32))))
179+
(x + a % UInt32) % T
176180
end
177181

178-
# helper function, to turn types to values, should be removed once we can do rand(Uniform(UInt))
179-
rand(rng::AbstractRNG, ::Val{T}) where {T} = rand(rng, T)
180-
181182
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt64,T}) where T
182183
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
183-
x = bw <= 52 ? rand_lteq(rng, UInt52Raw(), m, mask) :
184-
rand_lteq(rng, Val(UInt64), m, mask)
184+
x = bw <= 52 ? rand(rng, LessThan(m, Masked(mask, UInt52Raw()))) :
185+
rand(rng, LessThan(m, Masked(mask, uniform(UInt64))))
185186
(x + a % UInt64) % T
186187
end
187188

188189
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt128,T}) where T
189190
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
190-
x = bw <= 52 ? rand_lteq(rng, UInt52Raw(), m % UInt64, mask % UInt64) % UInt128 :
191-
bw <= 104 ? rand_lteq(rng, UInt104Raw(), m, mask) :
192-
rand_lteq(rng, Val(UInt128), m, mask)
191+
x = bw <= 52 ?
192+
rand(rng, LessThan(m % UInt64, Masked(mask % UInt64, UInt52Raw()))) % UInt128 :
193+
bw <= 104 ?
194+
rand(rng, LessThan(m, Masked(mask, UInt104Raw()))) :
195+
rand(rng, LessThan(m, Masked(mask, uniform(UInt128))))
193196
x % T + a
194197
end
195198

@@ -199,60 +202,70 @@ end
199202
rem_knuth(a::UInt, b::UInt) = a % (b + (b == 0)) + a * (b == 0)
200203
rem_knuth(a::T, b::T) where {T<:Unsigned} = b != 0 ? a % b : a
201204

202-
# maximum multiple of k <= 2^bits(T) decremented by one,
203-
# that is 0xFFFF...FFFF if k = typemax(T) - typemin(T) with intentional underflow
205+
# maximum multiple of k <= sup decremented by one,
206+
# that is 0xFFFF...FFFF if k = (typemax(T) - typemin(T)) + 1 and sup == typemax(T) - 1
207+
# with intentional underflow
204208
# see http://stackoverflow.com/questions/29182036/integer-arithmetic-add-1-to-uint-max-and-divide-by-n-without-overflow
205-
maxmultiple(k::T) where {T<:Unsigned} =
206-
(div(typemax(T) - k + one(k), k + (k == 0))*k + k - one(k))::T
207209

208-
# serves as rejection threshold
209-
_maxmultiple(k) = maxmultiple(k)
210+
# sup == 0 means typemax(T) + 1
211+
maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
212+
(div(sup - k, k + (k == 0))*k + k - one(k))::T
210213

211-
# maximum multiple of k within 1:2^32 or 1:2^64 decremented by one, depending on size
212-
_maxmultiple(k::UInt64)::UInt64 = k >> 32 != 0 ?
213-
maxmultiple(k) :
214-
div(0x0000000100000000, k + (k == 0))*k - one(k)
214+
# similar but sup must not be equal to typemax(T)
215+
unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} =
216+
div(sup, k + (k == 0))*k - one(k)
215217

216218
struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler
217-
a::T # first element of the range
218-
k::U # range length or zero for full range
219-
u::U # rejection threshold
219+
a::T # first element of the range
220+
bw::Int # bit width
221+
k::U # range length or zero for full range
222+
u::U # rejection threshold
220223
end
221224

222-
function SamplerRangeInt(a::T, diff::U) where {T<:Union{Bool,Integer},U<:Unsigned}
223-
k = diff+one(U)
224-
SamplerRangeInt{T,U}(a, k, _maxmultiple(k)) # overflow ok
225-
end
226225

227-
uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
228-
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
229-
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
226+
SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
227+
SamplerRangeInt(r, uint_sup(T))
230228

231-
function SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger}
229+
function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
232230
isempty(r) && throw(ArgumentError("range must be non-empty"))
233-
SamplerRangeInt(first(r), (last(r) - first(r)) % uint_sup(T))
231+
a = first(r)
232+
m = (last(r) - first(r)) % U
233+
k = m + one(U)
234+
bw = (sizeof(U) << 3 - leading_zeros(m)) % Int
235+
mult = if U === UInt32
236+
maxmultiple(k)
237+
elseif U === UInt64
238+
bw <= 52 ? unsafe_maxmultiple(k, one(UInt64) << 52) :
239+
maxmultiple(k)
240+
else # U === UInt128
241+
bw <= 52 ? unsafe_maxmultiple(k, one(UInt128) << 52) :
242+
bw <= 104 ? unsafe_maxmultiple(k, one(UInt128) << 104) :
243+
maxmultiple(k)
244+
end
245+
246+
SamplerRangeInt{T,U}(a, bw, k, mult) # overflow ok
234247
end
235248

236-
Sampler(::AbstractRNG, r::AbstractUnitRange{T}, ::Repetition) where {T<:Union{Bool,BitInteger}} =
237-
SamplerRangeInt(r)
249+
Sampler(::AbstractRNG, r::AbstractUnitRange{T},
250+
::Repetition) where {T<:Union{Bool,BitInteger}} = SamplerRangeInt(r)
251+
252+
rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:Union{Bool,BitInteger}} =
253+
(unsigned(sp.a) + rem_knuth(rand(rng, LessThan(sp.u, uniform(UInt32))), sp.k)) % T
238254

239-
function rand_lteq(rng::AbstractRNG, u::T)::T where T
240-
while true
241-
x = rand(rng, T)
242-
x <= u && return x
243-
end
244-
end
245255

246-
# this function uses 32 bit entropy for small ranges of length <= typemax(UInt32) + 1
256+
# this function uses 52 bit entropy for small ranges of length <= 2^52
247257
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt64}) where T<:BitInteger
248-
x::UInt64 = (sp.k - 1) >> 32 == 0 ?
249-
rand_lteq(rng, sp.u % UInt32) % UInt64 :
250-
rand_lteq(rng, sp.u)
258+
x = sp.bw <= 52 ? rand(rng, LessThan(sp.u, UInt52())) :
259+
rand(rng, LessThan(sp.u, uniform(UInt64)))
251260
return ((sp.a % UInt64) + rem_knuth(x, sp.k)) % T
252261
end
253262

254-
rand(rng::AbstractRNG, sp::SamplerRangeInt{T,U}) where {T<:Union{Bool,BitInteger},U} =
255-
(unsigned(sp.a) + rem_knuth(rand_lteq(rng, sp.u), sp.k)) % T
263+
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt128}) where T<:BitInteger
264+
x = sp.bw <= 52 ? rand(rng, LessThan(sp.u, UInt52(UInt128))) :
265+
sp.bw <= 104 ? rand(rng, LessThan(sp.u, UInt104(UInt128))) :
266+
rand(rng, LessThan(sp.u, uniform(UInt128)))
267+
return ((sp.a % UInt128) + rem_knuth(x, sp.k)) % T
268+
end
256269

257270

258271
### BigInt

base/random/misc.jl

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,10 @@ large.) Technically, this process is known as "Bernoulli sampling" of `A`.
141141
randsubseq(A::AbstractArray, p::Real) = randsubseq(GLOBAL_RNG, A, p)
142142

143143

144-
## rand_lt (helper function)
145-
146-
"Return a random `Int` (masked with `mask`) in ``[0, n)``, when `n <= 2^52`."
147-
@inline function rand_lt(r::AbstractRNG, n::Int, mask::Int=nextpow2(n)-1)
148-
# this duplicates the functionality of rand(1:n), to optimize this special case
149-
while true
150-
x = rand(r, UInt52Raw(Int)) & mask
151-
x < n && return x
152-
end
153-
end
144+
## rand Less Than Masked 52 bits (helper function)
154145

146+
"Return a sampler generating a random `Int` (masked with `mask`) in ``[0, n)``, when `n <= 2^52`."
147+
ltm52(n::Int, mask::Int=nextpow2(n)-1) = LessThan(n-1, Masked(mask, UInt52Raw(Int)))
155148

156149
## shuffle & shuffle!
157150

@@ -191,7 +184,7 @@ function shuffle!(r::AbstractRNG, a::AbstractArray)
191184
mask = nextpow2(n) - 1
192185
for i = n:-1:2
193186
(mask >> 1) == i && (mask >>= 1)
194-
j = 1 + rand_lt(r, i, mask)
187+
j = 1 + rand(r, ltm52(i, mask))
195188
a[i], a[j] = a[j], a[i]
196189
end
197190
return a
@@ -277,7 +270,7 @@ function randperm!(r::AbstractRNG, a::Array{<:Integer})
277270
a[1] = 1
278271
mask = 3
279272
@inbounds for i = 2:n
280-
j = 1 + rand_lt(r, i, mask)
273+
j = 1 + rand(r, ltm52(i, mask))
281274
if i != j # a[i] is uninitialized (and could be #undef)
282275
a[i] = a[j]
283276
end
@@ -339,7 +332,7 @@ function randcycle!(r::AbstractRNG, a::Array{<:Integer})
339332
a[1] = 1
340333
mask = 3
341334
@inbounds for i = 2:n
342-
j = 1 + rand_lt(r, i-1, mask)
335+
j = 1 + rand(r, ltm52(i-1, mask))
343336
a[i] = a[j]
344337
a[j] = i
345338
i == 1+mask && (mask = 2mask + 1)

base/random/random.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ Sampler(rng::AbstractRNG, sp::Sampler, ::Repetition) =
106106
Sampler(rng::AbstractRNG, X) = Sampler(rng, X, Val(Inf))
107107
Sampler(rng::AbstractRNG, ::Type{X}) where {X} = Sampler(rng, X, Val(Inf))
108108

109-
#### pre-defined useful Sampler subtypes
109+
#### pre-defined useful Sampler types
110110

111111
# default fall-back for types
112112
struct SamplerType{T} <: Sampler end
@@ -139,6 +139,38 @@ struct SamplerTag{T,S} <: Sampler
139139
end
140140

141141

142+
#### helper samplers
143+
144+
##### Adapter to generate a randome value in [0, n]
145+
146+
struct LessThan{T<:Integer,S} <: Sampler
147+
sup::T
148+
s::S # the scalar specification/sampler to feed to rand
149+
end
150+
151+
function rand(rng::AbstractRNG, sp::LessThan)
152+
while true
153+
x = rand(rng, sp.s)
154+
x <= sp.sup && return x
155+
end
156+
end
157+
158+
struct Masked{T<:Integer,S} <: Sampler
159+
mask::T
160+
s::S
161+
end
162+
163+
rand(rng::AbstractRNG, sp::Masked) = rand(rng, sp.s) & sp.mask
164+
165+
##### Uniform
166+
167+
struct UniformT{T} <: Sampler end
168+
169+
uniform(::Type{T}) where {T} = UniformT{T}()
170+
171+
rand(rng::AbstractRNG, ::UniformT{T}) where {T} = rand(rng, T)
172+
173+
142174
### machinery for generation with Sampler
143175

144176
# This describes how to generate random scalars or arrays, by generating a Sampler

test/random.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,13 @@ if sizeof(Int32) < sizeof(Int)
102102
local r = rand(Int32(-1):typemax(Int32))
103103
@test typeof(r) == Int32
104104
@test -1 <= r <= typemax(Int32)
105-
@test all([div(0x00010000000000000000,k)*k - 1 == Base.Random.RangeGenerator(map(UInt64,1:k)).u for k in 13 .+ Int64(2).^(32:62)])
106-
@test all([div(0x00010000000000000000,k)*k - 1 == Base.Random.RangeGenerator(map(Int64,1:k)).u for k in 13 .+ Int64(2).^(32:61)])
105+
for U = (Int64, UInt64)
106+
@test all(div(one(UInt128) << 52, k)*k - 1 == SamplerRangeInt(map(U, 1:k)).u
107+
for k in 13 .+ Int64(2).^(32:51))
108+
@test all(div(one(UInt128) << 64, k)*k - 1 == SamplerRangeInt(map(U, 1:k)).u
109+
for k in 13 .+ Int64(2).^(52:62))
110+
end
107111

108-
@test Base.Random._maxmultiple(0x000100000000) === 0xffffffffffffffff
109-
@test Base.Random._maxmultiple(0x0000FFFFFFFF) === 0x00000000fffffffe
110-
@test Base.Random._maxmultiple(0x000000000000) === 0xffffffffffffffff
111112
end
112113

113114
# BigInt specific
@@ -224,8 +225,11 @@ guardsrand() do
224225
@test r == rand(map(UInt64, 97:122))
225226
end
226227

227-
@test all([div(0x000100000000,k)*k - 1 == Base.Random.RangeGenerator(map(UInt64,1:k)).u for k in 13 .+ Int64(2).^(1:30)])
228-
@test all([div(0x000100000000,k)*k - 1 == Base.Random.RangeGenerator(map(Int64,1:k)).u for k in 13 .+ Int64(2).^(1:30)])
228+
for U in (Int64, UInt64)
229+
@test all(div(one(UInt64) << 52, k)*k - 1 == SamplerRangeInt(map(U, 1:k)).u
230+
for k in 13 .+ Int64(2).^(1:30))
231+
end
232+
229233

230234
import Base.Random: uuid1, uuid4, UUID, uuid_version
231235

0 commit comments

Comments
 (0)