Skip to content

Commit e058cc0

Browse files
committed
and a dedicated method for rand(::SamplerRangeFast{UInt32})
This could make it slightly faster in some cases, by avoiding a branch, but the main idea so to prepare the unification of code between this Sampler and SamplerRangeInt
1 parent c384517 commit e058cc0

File tree

1 file changed

+35
-31
lines changed

1 file changed

+35
-31
lines changed

base/random/generation.jl

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,30 @@ end
147147
# 2) "Default" which tries to use as few entropy bits as possible, at the cost of a
148148
# a bigger upfront price associated with the creation of the sampler
149149

150+
#### helper functions
151+
152+
function rand_lteq(r::AbstractRNG, S, u::U, mask::U) where U<:Integer
153+
while true
154+
x = rand(r, S) & mask
155+
x <= u && return x
156+
end
157+
end
158+
159+
function rand_lteq(rng::AbstractRNG, S, u::T)::T where T
160+
while true
161+
x = rand(rng, S)
162+
x <= u && return x
163+
end
164+
end
165+
166+
# helper function, to turn types to values, should be removed once we
167+
# can do rand(Uniform(UInt))
168+
rand(rng::AbstractRNG, ::Val{T}) where {T} = rand(rng, T)
169+
170+
uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
171+
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
172+
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
173+
150174
#### Fast
151175

152176
struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
@@ -156,32 +180,23 @@ struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
156180
mask::U # mask generated values before threshold rejection
157181
end
158182

159-
function SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Base.BitInteger64,Bool}
160-
isempty(r) && throw(ArgumentError("range must be non-empty"))
161-
m = last(r) % UInt64 - first(r) % UInt64
162-
bw = (64 - leading_zeros(m)) % UInt # bit-width
163-
mask = (1 % UInt64 << bw) - (1 % UInt64)
164-
SamplerRangeFast{UInt64,T}(first(r), bw, m, mask)
165-
end
183+
SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
184+
SamplerRangeFast(r, uint_sup(T))
166185

167-
function SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Int128,UInt128}
186+
function SamplerRangeFast(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
168187
isempty(r) && throw(ArgumentError("range must be non-empty"))
169-
m = (last(r)-first(r)) % UInt128
170-
bw = (128 - leading_zeros(m)) % UInt # bit-width
171-
mask = (1 % UInt128 << bw) - (1 % UInt128)
172-
SamplerRangeFast{UInt128,T}(first(r), bw, m, mask)
188+
m = (last(r) - first(r)) % U
189+
bw = (sizeof(U) << 3 - leading_zeros(m)) % UInt # bit-width
190+
mask = (1 % U << bw) - (1 % U)
191+
SamplerRangeFast{U,T}(first(r), bw, m, mask)
173192
end
174193

175-
function rand_lteq(r::AbstractRNG, S, u::U, mask::U) where U<:Integer
176-
while true
177-
x = rand(r, S) & mask
178-
x <= u && return x
179-
end
194+
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt32,T}) where T
195+
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
196+
x = rand_lteq(rng, Val(UInt32), m, mask)
197+
(x + a % UInt32) % T
180198
end
181199

182-
# helper function, to turn types to values, should be removed once we can do rand(Uniform(UInt))
183-
rand(rng::AbstractRNG, ::Val{T}) where {T} = rand(rng, T)
184-
185200
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt64,T}) where T
186201
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
187202
x = bw <= 52 ? rand_lteq(rng, UInt52Raw(), m, mask) :
@@ -216,17 +231,13 @@ maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
216231
unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} =
217232
div(sup, k + (k == 0))*k - one(k)
218233

219-
220234
struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler
221235
a::T # first element of the range
222236
bw::Int # bit width
223237
k::U # range length or zero for full range
224238
u::U # rejection threshold
225239
end
226240

227-
uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
228-
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
229-
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
230241

231242
SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
232243
SamplerRangeInt(r, uint_sup(T))
@@ -254,13 +265,6 @@ end
254265
Sampler(::AbstractRNG, r::AbstractUnitRange{T},
255266
::Repetition) where {T<:Union{Bool,BitInteger}} = SamplerRangeInt(r)
256267

257-
function rand_lteq(rng::AbstractRNG, S, u::T)::T where T
258-
while true
259-
x = rand(rng, S)
260-
x <= u && return x
261-
end
262-
end
263-
264268
rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:Union{Bool,BitInteger}} =
265269
(unsigned(sp.a) + rem_knuth(rand_lteq(rng, Val(UInt32), sp.u), sp.k)) % T
266270

0 commit comments

Comments
 (0)