Skip to content

Commit c384517

Browse files
committed
faster rand(::SamplerRangeInt)
This sampler was based on the idea that generating a UInt32 value is cheaper than generating a UInt64 one. Historically, this is based on the performance characteristics of MersenneTwister. Here, we refine this idea, knowing that this RNG produces natively 52 bits of entropy. So the thresholds determining how native values are generated each time are 52 bits and 104 bits. In particular, this becomes much more efficient for small (of length <= 2^52) Int128 ranges, for which previously 3 native values had to be generated instead of one.
1 parent 7708eb1 commit c384517

File tree

2 files changed

+68
-40
lines changed

2 files changed

+68
-40
lines changed

base/random/generation.jl

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,13 @@ rand(r::AbstractRNG, ::SamplerTrivial{UInt52Raw{UInt64}}) =
110110
_rand52(r::AbstractRNG, ::Type{Float64}) = reinterpret(UInt64, rand(r, Close1Open2()))
111111
_rand52(r::AbstractRNG, ::Type{UInt64}) = rand(r, UInt64)
112112

113-
rand(r::AbstractRNG, ::SamplerTrivial{UInt10{UInt16}}) = rand(r, UInt10Raw()) & 0x03ff
114-
rand(r::AbstractRNG, ::SamplerTrivial{UInt23{UInt32}}) = rand(r, UInt23Raw()) & 0x007fffff
115-
rand(r::AbstractRNG, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt52Raw()) & 0x000fffffffffffff
113+
rand(r::AbstractRNG, ::SamplerTrivial{UInt104Raw{UInt128}}) =
114+
rand(r, UInt52Raw(UInt128)) << 52 rand_inbounds(r, UInt52Raw(UInt128))
115+
116+
rand(r::AbstractRNG, ::SamplerTrivial{UInt10{UInt16}}) = rand(r, UInt10Raw()) & 0x03ff
117+
rand(r::AbstractRNG, ::SamplerTrivial{UInt23{UInt32}}) = rand(r, UInt23Raw()) & 0x007fffff
118+
rand(r::AbstractRNG, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt52Raw()) & 0x000fffffffffffff
119+
rand(r::AbstractRNG, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw()) & 0x000000ffffffffffffffffffffffffff
116120

117121
rand(r::AbstractRNG, sp::SamplerTrivial{<:UniformBits{T}}) where {T} =
118122
rand(r, uint_default(sp[])) % T
@@ -199,60 +203,80 @@ end
199203
rem_knuth(a::UInt, b::UInt) = a % (b + (b == 0)) + a * (b == 0)
200204
rem_knuth(a::T, b::T) where {T<:Unsigned} = b != 0 ? a % b : a
201205

202-
# maximum multiple of k <= 2^bits(T) decremented by one,
203-
# that is 0xFFFF...FFFF if k = typemax(T) - typemin(T) with intentional underflow
206+
# maximum multiple of k <= sup decremented by one,
207+
# that is 0xFFFF...FFFF if k = (typemax(T) - typemin(T)) + 1 and sup == typemax(T) - 1
208+
# with intentional underflow
204209
# see http://stackoverflow.com/questions/29182036/integer-arithmetic-add-1-to-uint-max-and-divide-by-n-without-overflow
205-
maxmultiple(k::T) where {T<:Unsigned} =
206-
(div(typemax(T) - k + one(k), k + (k == 0))*k + k - one(k))::T
207210

208-
# serves as rejection threshold
209-
_maxmultiple(k) = maxmultiple(k)
211+
# sup == 0 means typemax(T) + 1
212+
maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
213+
(div(sup - k, k + (k == 0))*k + k - one(k))::T
210214

211-
# maximum multiple of k within 1:2^32 or 1:2^64 decremented by one, depending on size
212-
_maxmultiple(k::UInt64)::UInt64 = k >> 32 != 0 ?
213-
maxmultiple(k) :
214-
div(0x0000000100000000, k + (k == 0))*k - one(k)
215+
# similar but sup must not be equal to typemax(T)
216+
unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} =
217+
div(sup, k + (k == 0))*k - one(k)
215218

216-
struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler
217-
a::T # first element of the range
218-
k::U # range length or zero for full range
219-
u::U # rejection threshold
220-
end
221219

222-
function SamplerRangeInt(a::T, diff::U) where {T<:Union{Bool,Integer},U<:Unsigned}
223-
k = diff+one(U)
224-
SamplerRangeInt{T,U}(a, k, _maxmultiple(k)) # overflow ok
220+
struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler
221+
a::T # first element of the range
222+
bw::Int # bit width
223+
k::U # range length or zero for full range
224+
u::U # rejection threshold
225225
end
226226

227227
uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
228228
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
229229
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
230230

231-
function SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger}
231+
SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
232+
SamplerRangeInt(r, uint_sup(T))
233+
234+
function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
232235
isempty(r) && throw(ArgumentError("range must be non-empty"))
233-
SamplerRangeInt(first(r), (last(r) - first(r)) % uint_sup(T))
236+
a = first(r)
237+
m = (last(r) - first(r)) % U
238+
k = m + one(U)
239+
bw = (sizeof(U) << 3 - leading_zeros(m)) % Int
240+
mult = if U === UInt32
241+
maxmultiple(k)
242+
elseif U === UInt64
243+
bw <= 52 ? unsafe_maxmultiple(k, one(UInt64) << 52) :
244+
maxmultiple(k)
245+
else # U === UInt128
246+
bw <= 52 ? unsafe_maxmultiple(k, one(UInt128) << 52) :
247+
bw <= 104 ? unsafe_maxmultiple(k, one(UInt128) << 104) :
248+
maxmultiple(k)
249+
end
250+
251+
SamplerRangeInt{T,U}(a, bw, k, mult) # overflow ok
234252
end
235253

236-
Sampler(::AbstractRNG, r::AbstractUnitRange{T}, ::Repetition) where {T<:Union{Bool,BitInteger}} =
237-
SamplerRangeInt(r)
254+
Sampler(::AbstractRNG, r::AbstractUnitRange{T},
255+
::Repetition) where {T<:Union{Bool,BitInteger}} = SamplerRangeInt(r)
238256

239-
function rand_lteq(rng::AbstractRNG, u::T)::T where T
257+
function rand_lteq(rng::AbstractRNG, S, u::T)::T where T
240258
while true
241-
x = rand(rng, T)
259+
x = rand(rng, S)
242260
x <= u && return x
243261
end
244262
end
245263

246-
# this function uses 32 bit entropy for small ranges of length <= typemax(UInt32) + 1
264+
rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:Union{Bool,BitInteger}} =
265+
(unsigned(sp.a) + rem_knuth(rand_lteq(rng, Val(UInt32), sp.u), sp.k)) % T
266+
267+
# this function uses 52 bit entropy for small ranges of length <= 2^52
247268
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt64}) where T<:BitInteger
248-
x::UInt64 = (sp.k - 1) >> 32 == 0 ?
249-
rand_lteq(rng, sp.u % UInt32) % UInt64 :
250-
rand_lteq(rng, sp.u)
269+
x = sp.bw <= 52 ? rand_lteq(rng, UInt52(), sp.u) :
270+
rand_lteq(rng, Val(UInt64), sp.u)
251271
return ((sp.a % UInt64) + rem_knuth(x, sp.k)) % T
252272
end
253273

254-
rand(rng::AbstractRNG, sp::SamplerRangeInt{T,U}) where {T<:Union{Bool,BitInteger},U} =
255-
(unsigned(sp.a) + rem_knuth(rand_lteq(rng, sp.u), sp.k)) % T
274+
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt128}) where T<:BitInteger
275+
x = sp.bw <= 52 ? rand_lteq(rng, UInt52(UInt128), sp.u) :
276+
sp.bw <= 104 ? rand_lteq(rng, UInt104(UInt128), sp.u) :
277+
rand_lteq(rng, Val(UInt128), sp.u)
278+
return ((sp.a % UInt128) + rem_knuth(x, sp.k)) % T
279+
end
256280

257281

258282
### BigInt

test/random.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,13 @@ if sizeof(Int32) < sizeof(Int)
102102
local r = rand(Int32(-1):typemax(Int32))
103103
@test typeof(r) == Int32
104104
@test -1 <= r <= typemax(Int32)
105-
@test all([div(0x00010000000000000000,k)*k - 1 == Base.Random.RangeGenerator(map(UInt64,1:k)).u for k in 13 .+ Int64(2).^(32:62)])
106-
@test all([div(0x00010000000000000000,k)*k - 1 == Base.Random.RangeGenerator(map(Int64,1:k)).u for k in 13 .+ Int64(2).^(32:61)])
105+
for U = (Int64, UInt64)
106+
@test all(div(one(UInt128) << 52, k)*k - 1 == SamplerRangeInt(map(U, 1:k)).u
107+
for k in 13 .+ Int64(2).^(32:51))
108+
@test all(div(one(UInt128) << 64, k)*k - 1 == SamplerRangeInt(map(U, 1:k)).u
109+
for k in 13 .+ Int64(2).^(52:62))
110+
end
107111

108-
@test Base.Random._maxmultiple(0x000100000000) === 0xffffffffffffffff
109-
@test Base.Random._maxmultiple(0x0000FFFFFFFF) === 0x00000000fffffffe
110-
@test Base.Random._maxmultiple(0x000000000000) === 0xffffffffffffffff
111112
end
112113

113114
# BigInt specific
@@ -224,8 +225,11 @@ guardsrand() do
224225
@test r == rand(map(UInt64, 97:122))
225226
end
226227

227-
@test all([div(0x000100000000,k)*k - 1 == Base.Random.RangeGenerator(map(UInt64,1:k)).u for k in 13 .+ Int64(2).^(1:30)])
228-
@test all([div(0x000100000000,k)*k - 1 == Base.Random.RangeGenerator(map(Int64,1:k)).u for k in 13 .+ Int64(2).^(1:30)])
228+
for U in (Int64, UInt64)
229+
@test all(div(one(UInt64) << 52, k)*k - 1 == SamplerRangeInt(map(U, 1:k)).u
230+
for k in 13 .+ Int64(2).^(1:30))
231+
end
232+
229233

230234
import Base.Random: uuid1, uuid4, UUID, uuid_version
231235

0 commit comments

Comments
 (0)