Skip to content

Commit c8dd3fd

Browse files
rfourquetandreasnoack
authored andcommitted
simplify rand(::UnitRange) a bit (#24989)
1 parent 40e7d29 commit c8dd3fd

File tree

4 files changed

+38
-48
lines changed

4 files changed

+38
-48
lines changed

base/random/RNGs.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# SamplerUnion(Union{X,Y,...}) == Union{SamplerType{X},SamplerType{Y},...}
66
SamplerUnion(U::Union) = Union{map(T->SamplerType{T}, Base.uniontypes(U))...}
7-
const SamplerBoolBitInteger = SamplerUnion(Union{Bool, Base.BitInteger})
7+
const SamplerBoolBitInteger = SamplerUnion(Union{Bool, BitInteger})
88

99
if Sys.iswindows()
1010
struct RandomDevice <: AbstractRNG
@@ -30,7 +30,7 @@ else # !windows
3030
end # os-test
3131

3232
# NOTE: this can't be put within the if-else block above
33-
for T in (Bool, Base.BitInteger_types...)
33+
for T in (Bool, BitInteger_types...)
3434
if Sys.iswindows()
3535
@eval function rand!(rd::RandomDevice, A::Array{$T}, ::SamplerType{$T})
3636
ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32),
@@ -403,7 +403,7 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, ::SamplerType{UInt128})
403403
A
404404
end
405405

406-
for T in Base.BitInteger_types
406+
for T in BitInteger_types
407407
T === UInt128 && continue
408408
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerType{$T})
409409
n = length(A)
@@ -451,7 +451,7 @@ function rand(rng::MersenneTwister,
451451
x % T + first(r)
452452
end
453453

454-
for T in (Bool, Base.BitInteger_types...) # eval because of ambiguity otherwise
454+
for T in (Bool, BitInteger_types...) # eval because of ambiguity otherwise
455455
@eval Sampler(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
456456
SamplerTrivial(r)
457457
end

base/random/generation.jl

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -133,68 +133,57 @@ rem_knuth(a::T, b::T) where {T<:Unsigned} = b != 0 ? a % b : a
133133
# that is 0xFFFF...FFFF if k = typemax(T) - typemin(T) with intentional underflow
134134
# see http://stackoverflow.com/questions/29182036/integer-arithmetic-add-1-to-uint-max-and-divide-by-n-without-overflow
135135
maxmultiple(k::T) where {T<:Unsigned} =
136-
(div(typemax(T) - k + oneunit(k), k + (k == 0))*k + k - oneunit(k))::T
136+
(div(typemax(T) - k + one(k), k + (k == 0))*k + k - one(k))::T
137+
138+
# serves as rejection threshold
139+
_maxmultiple(k) = maxmultiple(k)
137140

138141
# maximum multiple of k within 1:2^32 or 1:2^64 decremented by one, depending on size
139-
maxmultiplemix(k::UInt64) = k >> 32 != 0 ?
142+
_maxmultiple(k::UInt64)::UInt64 = k >> 32 != 0 ?
140143
maxmultiple(k) :
141-
(div(0x0000000100000000, k + (k == 0))*k - oneunit(k))::UInt64
144+
div(0x0000000100000000, k + (k == 0))*k - one(k)
142145

143-
struct SamplerRangeInt{T<:Integer,U<:Unsigned} <: Sampler
146+
struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler
144147
a::T # first element of the range
145148
k::U # range length or zero for full range
146149
u::U # rejection threshold
147150
end
148151

149-
# generators with 32, 128 bits entropy
150-
SamplerRangeInt(a::T, k::U) where {T,U<:Union{UInt32,UInt128}} =
151-
SamplerRangeInt{T,U}(a, k, maxmultiple(k))
152+
function SamplerRangeInt(a::T, diff::U) where {T<:Union{Bool,Integer},U<:Unsigned}
153+
k = diff+one(U)
154+
SamplerRangeInt{T,U}(a, k, _maxmultiple(k)) # overflow ok
155+
end
152156

153-
# mixed 32/64 bits entropy generator
154-
SamplerRangeInt(a::T, k::UInt64) where {T} =
155-
SamplerRangeInt{T,UInt64}(a, k, maxmultiplemix(k))
157+
uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
158+
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
159+
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
156160

157-
function Sampler(::AbstractRNG, r::AbstractUnitRange{T}, ::Repetition) where T<:Unsigned
161+
function SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger}
158162
isempty(r) && throw(ArgumentError("range must be non-empty"))
159-
SamplerRangeInt(first(r), last(r) - first(r) + oneunit(T))
163+
SamplerRangeInt(first(r), (last(r) - first(r)) % uint_sup(T))
160164
end
161165

162-
for (T, U) in [(UInt8, UInt32), (UInt16, UInt32),
163-
(Int8, UInt32), (Int16, UInt32), (Int32, UInt32),
164-
(Int64, UInt64), (Int128, UInt128), (Bool, UInt32)]
166+
Sampler(::AbstractRNG, r::AbstractUnitRange{T}, ::Repetition) where {T<:Union{Bool,BitInteger}} =
167+
SamplerRangeInt(r)
165168

166-
@eval Sampler(::AbstractRNG, r::AbstractUnitRange{$T}, ::Repetition) = begin
167-
isempty(r) && throw(ArgumentError("range must be non-empty"))
168-
# overflow ok:
169-
SamplerRangeInt(first(r), convert($U, unsigned(last(r) - first(r)) + one($U)))
169+
function rand_lteq(rng::AbstractRNG, u::T)::T where T
170+
while true
171+
x = rand(rng, T)
172+
x <= u && return x
170173
end
171174
end
172175

173176
# this function uses 32 bit entropy for small ranges of length <= typemax(UInt32) + 1
174-
# SamplerRangeInt is responsible for providing the right value of k
175-
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt64}) where T<:Union{UInt64,Int64}
176-
local x::UInt64
177-
if (sp.k - 1) >> 32 == 0
178-
x = rand(rng, UInt32)
179-
while x > sp.u
180-
x = rand(rng, UInt32)
181-
end
182-
else
183-
x = rand(rng, UInt64)
184-
while x > sp.u
185-
x = rand(rng, UInt64)
186-
end
187-
end
188-
return reinterpret(T, reinterpret(UInt64, sp.a) + rem_knuth(x, sp.k))
177+
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt64}) where T<:BitInteger
178+
x::UInt64 = (sp.k - 1) >> 32 == 0 ?
179+
rand_lteq(rng, sp.u % UInt32) % UInt64 :
180+
rand_lteq(rng, sp.u)
181+
return ((sp.a % UInt64) + rem_knuth(x, sp.k)) % T
189182
end
190183

191-
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,U}) where {T<:Integer,U<:Unsigned}
192-
x = rand(rng, U)
193-
while x > sp.u
194-
x = rand(rng, U)
195-
end
196-
(unsigned(sp.a) + rem_knuth(x, sp.k)) % T
197-
end
184+
rand(rng::AbstractRNG, sp::SamplerRangeInt{T,U}) where {T<:Union{Bool,BitInteger},U} =
185+
(unsigned(sp.a) + rem_knuth(rand_lteq(rng, sp.u), sp.k)) % T
186+
198187

199188
### BigInt
200189

base/random/random.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ module Random
44

55
using Base.dSFMT
66
using Base.GMP: Limb, MPZ
7+
using Base: BitInteger, BitInteger_types
78
import Base: copymutable, copy, copy!, ==, hash
89

910
export srand,

test/random.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@ if sizeof(Int32) < sizeof(Int)
103103
@test all([div(0x00010000000000000000,k)*k - 1 == Base.Random.RangeGenerator(map(UInt64,1:k)).u for k in 13 .+ Int64(2).^(32:62)])
104104
@test all([div(0x00010000000000000000,k)*k - 1 == Base.Random.RangeGenerator(map(Int64,1:k)).u for k in 13 .+ Int64(2).^(32:61)])
105105

106-
@test Base.Random.maxmultiplemix(0x000100000000) === 0xffffffffffffffff
107-
@test Base.Random.maxmultiplemix(0x0000FFFFFFFF) === 0x00000000fffffffe
108-
@test Base.Random.maxmultiplemix(0x000000000000) === 0xffffffffffffffff
106+
@test Base.Random._maxmultiple(0x000100000000) === 0xffffffffffffffff
107+
@test Base.Random._maxmultiple(0x0000FFFFFFFF) === 0x00000000fffffffe
108+
@test Base.Random._maxmultiple(0x000000000000) === 0xffffffffffffffff
109109
end
110110

111111
# BigInt specific

0 commit comments

Comments
 (0)