147
147
# 2) "Default" which tries to use as few entropy bits as possible, at the cost of a
148
148
# a bigger upfront price associated with the creation of the sampler
149
149
150
+ # ### helper functions
151
+
152
+ function rand_lteq (r:: AbstractRNG , S, u:: U , mask:: U ) where U<: Integer
153
+ while true
154
+ x = rand (r, S) & mask
155
+ x <= u && return x
156
+ end
157
+ end
158
+
159
+ function rand_lteq (rng:: AbstractRNG , S, u:: T ):: T where T
160
+ while true
161
+ x = rand (rng, S)
162
+ x <= u && return x
163
+ end
164
+ end
165
+
166
+ # helper function, to turn types to values, should be removed once we
167
+ # can do rand(Uniform(UInt))
168
+ rand (rng:: AbstractRNG , :: Val{T} ) where {T} = rand (rng, T)
169
+
170
+ uint_sup (:: Type{<:Union{Bool,BitInteger}} ) = UInt32
171
+ uint_sup (:: Type{<:Union{Int64,UInt64}} ) = UInt64
172
+ uint_sup (:: Type{<:Union{Int128,UInt128}} ) = UInt128
173
+
150
174
# ### Fast
151
175
152
176
struct SamplerRangeFast{U<: BitUnsigned ,T<: Union{BitInteger,Bool} } <: Sampler
@@ -156,32 +180,23 @@ struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
156
180
mask:: U # mask generated values before threshold rejection
157
181
end
158
182
159
- function SamplerRangeFast (r:: AbstractUnitRange{T} ) where T<: Union{Base.BitInteger64,Bool}
160
- isempty (r) && throw (ArgumentError (" range must be non-empty" ))
161
- m = last (r) % UInt64 - first (r) % UInt64
162
- bw = (64 - leading_zeros (m)) % UInt # bit-width
163
- mask = (1 % UInt64 << bw) - (1 % UInt64)
164
- SamplerRangeFast {UInt64,T} (first (r), bw, m, mask)
165
- end
183
+ SamplerRangeFast (r:: AbstractUnitRange{T} ) where T<: Union{Bool,BitInteger} =
184
+ SamplerRangeFast (r, uint_sup (T))
166
185
167
- function SamplerRangeFast (r:: AbstractUnitRange{T} ) where T <: Union{Int128,UInt128 }
186
+ function SamplerRangeFast (r:: AbstractUnitRange{T} , :: Type{U} ) where {T,U }
168
187
isempty (r) && throw (ArgumentError (" range must be non-empty" ))
169
- m = (last (r)- first (r)) % UInt128
170
- bw = (128 - leading_zeros (m)) % UInt # bit-width
171
- mask = (1 % UInt128 << bw) - (1 % UInt128 )
172
- SamplerRangeFast {UInt128 ,T} (first (r), bw, m, mask)
188
+ m = (last (r) - first (r)) % U
189
+ bw = (sizeof (U) << 3 - leading_zeros (m)) % UInt # bit-width
190
+ mask = (1 % U << bw) - (1 % U )
191
+ SamplerRangeFast {U ,T} (first (r), bw, m, mask)
173
192
end
174
193
175
- function rand_lteq (r:: AbstractRNG , S, u:: U , mask:: U ) where U<: Integer
176
- while true
177
- x = rand (r, S) & mask
178
- x <= u && return x
179
- end
194
+ function rand (rng:: AbstractRNG , sp:: SamplerRangeFast{UInt32,T} ) where T
195
+ a, bw, m, mask = sp. a, sp. bw, sp. m, sp. mask
196
+ x = rand_lteq (rng, Val (UInt32), m, mask)
197
+ (x + a % UInt32) % T
180
198
end
181
199
182
- # helper function, to turn types to values, should be removed once we can do rand(Uniform(UInt))
183
- rand (rng:: AbstractRNG , :: Val{T} ) where {T} = rand (rng, T)
184
-
185
200
function rand (rng:: AbstractRNG , sp:: SamplerRangeFast{UInt64,T} ) where T
186
201
a, bw, m, mask = sp. a, sp. bw, sp. m, sp. mask
187
202
x = bw <= 52 ? rand_lteq (rng, UInt52Raw (), m, mask) :
@@ -216,17 +231,13 @@ maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
216
231
unsafe_maxmultiple (k:: T , sup:: T ) where {T<: Unsigned } =
217
232
div (sup, k + (k == 0 ))* k - one (k)
218
233
219
-
220
234
struct SamplerRangeInt{T<: Union{Bool,Integer} ,U<: Unsigned } <: Sampler
221
235
a:: T # first element of the range
222
236
bw:: Int # bit width
223
237
k:: U # range length or zero for full range
224
238
u:: U # rejection threshold
225
239
end
226
240
227
- uint_sup (:: Type{<:Union{Bool,BitInteger}} ) = UInt32
228
- uint_sup (:: Type{<:Union{Int64,UInt64}} ) = UInt64
229
- uint_sup (:: Type{<:Union{Int128,UInt128}} ) = UInt128
230
241
231
242
SamplerRangeInt (r:: AbstractUnitRange{T} ) where T<: Union{Bool,BitInteger} =
232
243
SamplerRangeInt (r, uint_sup (T))
254
265
Sampler (:: AbstractRNG , r:: AbstractUnitRange{T} ,
255
266
:: Repetition ) where {T<: Union{Bool,BitInteger} } = SamplerRangeInt (r)
256
267
257
- function rand_lteq (rng:: AbstractRNG , S, u:: T ):: T where T
258
- while true
259
- x = rand (rng, S)
260
- x <= u && return x
261
- end
262
- end
263
-
264
268
rand (rng:: AbstractRNG , sp:: SamplerRangeInt{T,UInt32} ) where {T<: Union{Bool,BitInteger} } =
265
269
(unsigned (sp. a) + rem_knuth (rand_lteq (rng, Val (UInt32), sp. u), sp. k)) % T
266
270
0 commit comments