Skip to content

Commit 2da4c18

Browse files
authored
Merge pull request #25002 from JuliaLang/rf/rand/samplerize-MT-range
RNG: make a Sampler out of rand(::MersenneTwister, ::UnitRange)
2 parents 884c956 + 66f580e commit 2da4c18

File tree

4 files changed

+68
-34
lines changed

4 files changed

+68
-34
lines changed

base/random/RNGs.jl

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -419,41 +419,9 @@ end
419419

420420
#### from a range
421421

422-
function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer
423-
while true
424-
x = randfun(r) & mask
425-
x <= u && return x
426-
end
427-
end
428-
429-
function rand(rng::MersenneTwister,
430-
sp::SamplerTrivial{UnitRange{T}}) where T<:Union{Base.BitInteger64,Bool}
431-
r = sp[]
432-
isempty(r) && throw(ArgumentError("range must be non-empty"))
433-
m = last(r) % UInt64 - first(r) % UInt64
434-
bw = (64 - leading_zeros(m)) % UInt # bit-width
435-
mask = (1 % UInt64 << bw) - (1 % UInt64)
436-
x = bw <= 52 ? rand_lteq(rng, rand_ui52_raw, m, mask) :
437-
rand_lteq(rng, rng->rand(rng, UInt64), m, mask)
438-
(x + first(r) % UInt64) % T
439-
end
440-
441-
function rand(rng::MersenneTwister,
442-
sp::SamplerTrivial{UnitRange{T}}) where T<:Union{Int128,UInt128}
443-
r = sp[]
444-
isempty(r) && throw(ArgumentError("range must be non-empty"))
445-
m = (last(r)-first(r)) % UInt128
446-
bw = (128 - leading_zeros(m)) % UInt # bit-width
447-
mask = (1 % UInt128 << bw) - (1 % UInt128)
448-
x = bw <= 52 ? rand_lteq(rng, rand_ui52_raw, m % UInt64, mask % UInt64) % UInt128 :
449-
bw <= 104 ? rand_lteq(rng, rand_ui104_raw, m, mask) :
450-
rand_lteq(rng, rng->rand(rng, UInt128), m, mask)
451-
x % T + first(r)
452-
end
453-
454422
for T in (Bool, BitInteger_types...) # eval because of ambiguity otherwise
455423
@eval Sampler(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
456-
SamplerTrivial(r)
424+
SamplerRangeFast(r)
457425
end
458426

459427

base/random/generation.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,63 @@ end
125125

126126
### BitInteger
127127

128+
# there are two implemented samplers for unit ranges, which assume that Float64 (i.e.
129+
# 52 random bits) is the native type for the RNG:
130+
# 1) "Fast", which is the most efficient when the underlying RNG produces rand(Float64)
131+
# "fast enough". The tradeoff is faster creation of the sampler, but more
132+
# consumption of entropy bits
133+
# 2) "Default" which tries to use as few entropy bits as possible, at the cost of a
134+
# a bigger upfront price associated with the creation of the sampler
135+
136+
#### Fast
137+
138+
struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
139+
a::T # first element of the range
140+
bw::UInt # bit width
141+
m::U # range length - 1
142+
mask::U # mask generated values before threshold rejection
143+
end
144+
145+
function SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Base.BitInteger64,Bool}
146+
isempty(r) && throw(ArgumentError("range must be non-empty"))
147+
m = last(r) % UInt64 - first(r) % UInt64
148+
bw = (64 - leading_zeros(m)) % UInt # bit-width
149+
mask = (1 % UInt64 << bw) - (1 % UInt64)
150+
SamplerRangeFast{UInt64,T}(first(r), bw, m, mask)
151+
end
152+
153+
function SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Int128,UInt128}
154+
isempty(r) && throw(ArgumentError("range must be non-empty"))
155+
m = (last(r)-first(r)) % UInt128
156+
bw = (128 - leading_zeros(m)) % UInt # bit-width
157+
mask = (1 % UInt128 << bw) - (1 % UInt128)
158+
SamplerRangeFast{UInt128,T}(first(r), bw, m, mask)
159+
end
160+
161+
function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer
162+
while true
163+
x = randfun(r) & mask
164+
x <= u && return x
165+
end
166+
end
167+
168+
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt64,T}) where T
169+
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
170+
x = bw <= 52 ? rand_lteq(rng, rand_ui52_raw, m, mask) :
171+
rand_lteq(rng, rng->rand(rng, UInt64), m, mask)
172+
(x + a % UInt64) % T
173+
end
174+
175+
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt128,T}) where T
176+
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
177+
x = bw <= 52 ? rand_lteq(rng, rand_ui52_raw, m % UInt64, mask % UInt64) % UInt128 :
178+
bw <= 104 ? rand_lteq(rng, rand_ui104_raw, m, mask) :
179+
rand_lteq(rng, rng->rand(rng, UInt128), m, mask)
180+
x % T + a
181+
end
182+
183+
#### Default
184+
128185
# remainder function according to Knuth, where rem_knuth(a, 0) = a
129186
rem_knuth(a::UInt, b::UInt) = a % (b + (b == 0)) + a * (b == 0)
130187
rem_knuth(a::T, b::T) where {T<:Unsigned} = b != 0 ? a % b : a

base/random/random.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module Random
44

55
using Base.dSFMT
66
using Base.GMP: Limb, MPZ
7-
using Base: BitInteger, BitInteger_types
7+
using Base: BitInteger, BitInteger_types, BitUnsigned
88
import Base: copymutable, copy, copy!, ==, hash
99

1010
export srand,

test/random.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
isdefined(Main, :TestHelpers) || @eval Main include(joinpath(dirname(@__FILE__), "TestHelpers.jl"))
44
using Main.TestHelpers.OAs
55

6+
using Base.Random: Sampler, SamplerRangeFast, SamplerRangeInt
7+
68
# Issue #6573
79
guardsrand(0) do
810
rand()
@@ -647,3 +649,10 @@ struct RandomStruct23964 end
647649
@test_throws ArgumentError rand(nothing)
648650
@test_throws ArgumentError rand(RandomStruct23964())
649651
end
652+
653+
@testset "rand(::$RNG, ::UnitRange{$T}" for RNG (MersenneTwister(), RandomDevice()),
654+
T (Int32, UInt32, Int64, Int128, UInt128)
655+
RNG isa MersenneTwister && srand(RNG, rand(UInt128)) # for reproducibility
656+
r = T(1):T(108)
657+
@test rand(RNG, SamplerRangeFast(r)) r
658+
end

0 commit comments

Comments
 (0)