Skip to content

Commit 5bb6b4d

Browse files
committed
unify the 3 rand_lt[eq] functions as composable samplers
1 parent e058cc0 commit 5bb6b4d

File tree

3 files changed

+54
-44
lines changed

3 files changed

+54
-44
lines changed

base/random/generation.jl

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -149,24 +149,6 @@ end
149149

150150
#### helper functions
151151

152-
function rand_lteq(r::AbstractRNG, S, u::U, mask::U) where U<:Integer
153-
while true
154-
x = rand(r, S) & mask
155-
x <= u && return x
156-
end
157-
end
158-
159-
function rand_lteq(rng::AbstractRNG, S, u::T)::T where T
160-
while true
161-
x = rand(rng, S)
162-
x <= u && return x
163-
end
164-
end
165-
166-
# helper function, to turn types to values, should be removed once we
167-
# can do rand(Uniform(UInt))
168-
rand(rng::AbstractRNG, ::Val{T}) where {T} = rand(rng, T)
169-
170152
uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
171153
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
172154
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
@@ -193,22 +175,24 @@ end
193175

194176
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt32,T}) where T
195177
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
196-
x = rand_lteq(rng, Val(UInt32), m, mask)
178+
x = rand(rng, LessThan(m, Masked(mask, uniform(UInt32))))
197179
(x + a % UInt32) % T
198180
end
199181

200182
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt64,T}) where T
201183
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
202-
x = bw <= 52 ? rand_lteq(rng, UInt52Raw(), m, mask) :
203-
rand_lteq(rng, Val(UInt64), m, mask)
184+
x = bw <= 52 ? rand(rng, LessThan(m, Masked(mask, UInt52Raw()))) :
185+
rand(rng, LessThan(m, Masked(mask, uniform(UInt64))))
204186
(x + a % UInt64) % T
205187
end
206188

207189
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt128,T}) where T
208190
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
209-
x = bw <= 52 ? rand_lteq(rng, UInt52Raw(), m % UInt64, mask % UInt64) % UInt128 :
210-
bw <= 104 ? rand_lteq(rng, UInt104Raw(), m, mask) :
211-
rand_lteq(rng, Val(UInt128), m, mask)
191+
x = bw <= 52 ?
192+
rand(rng, LessThan(m % UInt64, Masked(mask % UInt64, UInt52Raw()))) % UInt128 :
193+
bw <= 104 ?
194+
rand(rng, LessThan(m, Masked(mask, UInt104Raw()))) :
195+
rand(rng, LessThan(m, Masked(mask, uniform(UInt128))))
212196
x % T + a
213197
end
214198

@@ -266,19 +250,20 @@ Sampler(::AbstractRNG, r::AbstractUnitRange{T},
266250
::Repetition) where {T<:Union{Bool,BitInteger}} = SamplerRangeInt(r)
267251

268252
rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:Union{Bool,BitInteger}} =
269-
(unsigned(sp.a) + rem_knuth(rand_lteq(rng, Val(UInt32), sp.u), sp.k)) % T
253+
(unsigned(sp.a) + rem_knuth(rand(rng, LessThan(sp.u, uniform(UInt32))), sp.k)) % T
254+
270255

271256
# this function uses 52 bit entropy for small ranges of length <= 2^52
272257
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt64}) where T<:BitInteger
273-
x = sp.bw <= 52 ? rand_lteq(rng, UInt52(), sp.u) :
274-
rand_lteq(rng, Val(UInt64), sp.u)
258+
x = sp.bw <= 52 ? rand(rng, LessThan(sp.u, UInt52())) :
259+
rand(rng, LessThan(sp.u, uniform(UInt64)))
275260
return ((sp.a % UInt64) + rem_knuth(x, sp.k)) % T
276261
end
277262

278263
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt128}) where T<:BitInteger
279-
x = sp.bw <= 52 ? rand_lteq(rng, UInt52(UInt128), sp.u) :
280-
sp.bw <= 104 ? rand_lteq(rng, UInt104(UInt128), sp.u) :
281-
rand_lteq(rng, Val(UInt128), sp.u)
264+
x = sp.bw <= 52 ? rand(rng, LessThan(sp.u, UInt52(UInt128))) :
265+
sp.bw <= 104 ? rand(rng, LessThan(sp.u, UInt104(UInt128))) :
266+
rand(rng, LessThan(sp.u, uniform(UInt128)))
282267
return ((sp.a % UInt128) + rem_knuth(x, sp.k)) % T
283268
end
284269

base/random/misc.jl

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,10 @@ large.) Technically, this process is known as "Bernoulli sampling" of `A`.
141141
randsubseq(A::AbstractArray, p::Real) = randsubseq(GLOBAL_RNG, A, p)
142142

143143

144-
## rand_lt (helper function)
145-
146-
"Return a random `Int` (masked with `mask`) in ``[0, n)``, when `n <= 2^52`."
147-
@inline function rand_lt(r::AbstractRNG, n::Int, mask::Int=nextpow2(n)-1)
148-
# this duplicates the functionality of rand(1:n), to optimize this special case
149-
while true
150-
x = rand(r, UInt52Raw(Int)) & mask
151-
x < n && return x
152-
end
153-
end
144+
## rand Less Than Masked 52 bits (helper function)
154145

146+
"Return a sampler generating a random `Int` (masked with `mask`) in ``[0, n)``, when `n <= 2^52`."
147+
ltm52(n::Int, mask::Int=nextpow2(n)-1) = LessThan(n-1, Masked(mask, UInt52Raw(Int)))
155148

156149
## shuffle & shuffle!
157150

@@ -191,7 +184,7 @@ function shuffle!(r::AbstractRNG, a::AbstractArray)
191184
mask = nextpow2(n) - 1
192185
for i = n:-1:2
193186
(mask >> 1) == i && (mask >>= 1)
194-
j = 1 + rand_lt(r, i, mask)
187+
j = 1 + rand(r, ltm52(i, mask))
195188
a[i], a[j] = a[j], a[i]
196189
end
197190
return a
@@ -277,7 +270,7 @@ function randperm!(r::AbstractRNG, a::Array{<:Integer})
277270
a[1] = 1
278271
mask = 3
279272
@inbounds for i = 2:n
280-
j = 1 + rand_lt(r, i, mask)
273+
j = 1 + rand(r, ltm52(i, mask))
281274
if i != j # a[i] is uninitialized (and could be #undef)
282275
a[i] = a[j]
283276
end
@@ -339,7 +332,7 @@ function randcycle!(r::AbstractRNG, a::Array{<:Integer})
339332
a[1] = 1
340333
mask = 3
341334
@inbounds for i = 2:n
342-
j = 1 + rand_lt(r, i-1, mask)
335+
j = 1 + rand(r, ltm52(i-1, mask))
343336
a[i] = a[j]
344337
a[j] = i
345338
i == 1+mask && (mask = 2mask + 1)

base/random/random.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ Sampler(rng::AbstractRNG, sp::Sampler, ::Repetition) =
105105
Sampler(rng::AbstractRNG, X) = Sampler(rng, X, Val(Inf))
106106
Sampler(rng::AbstractRNG, ::Type{X}) where {X} = Sampler(rng, X, Val(Inf))
107107

108-
#### pre-defined useful Sampler subtypes
108+
#### pre-defined useful Sampler types
109109

110110
# default fall-back for types
111111
struct SamplerType{T} <: Sampler end
@@ -138,6 +138,38 @@ struct SamplerTag{T,S} <: Sampler
138138
end
139139

140140

141+
#### helper samplers
142+
143+
##### Adapter to generate a randome value in [0, n]
144+
145+
struct LessThan{T<:Integer,S} <: Sampler
146+
sup::T
147+
s::S # the scalar specification/sampler to feed to rand
148+
end
149+
150+
function rand(rng::AbstractRNG, sp::LessThan)
151+
while true
152+
x = rand(rng, sp.s)
153+
x <= sp.sup && return x
154+
end
155+
end
156+
157+
struct Masked{T<:Integer,S} <: Sampler
158+
mask::T
159+
s::S
160+
end
161+
162+
rand(rng::AbstractRNG, sp::Masked) = rand(rng, sp.s) & sp.mask
163+
164+
##### Uniform
165+
166+
struct UniformT{T} <: Sampler end
167+
168+
uniform(::Type{T}) where {T} = UniformT{T}()
169+
170+
rand(rng::AbstractRNG, ::UniformT{T}) where {T} = rand(rng, T)
171+
172+
141173
### machinery for generation with Sampler
142174

143175
# This describes how to generate random scalars or arrays, by generating a Sampler

0 commit comments

Comments
 (0)