Skip to content

Commit 2c91d7f

Browse files
authored
slightly faster rand(big(a):big(b)) (#41066)
For generating `x` in a `0:m` BigInt range, instead of uniformly randomizing all the limbs of `x` and rejecting when out of range, we generate the highest limb `hx` of `x` in `0:hm` where `hm` is the highest limb of `m`. Before the introduction of `SamplerRangeNDL`, this would have changed nothing, as the predecessor `SamplerRangeFast` was itself using rejection sampling by generating first a given number of random bits. But with NDL, this can speed-up `BigInt` generation by very roughly 10%, as in general almost no "rejection" will happen.
1 parent 20c0baf commit 2c91d7f

File tree

2 files changed

+41
-19
lines changed

2 files changed

+41
-19
lines changed

stdlib/Random/src/generation.jl

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -358,45 +358,56 @@ end
358358

359359
### BigInt
360360

361-
struct SamplerBigInt <: Sampler{BigInt}
361+
struct SamplerBigInt{SP<:Sampler{Limb}} <: Sampler{BigInt}
362362
a::BigInt # first
363363
m::BigInt # range length - 1
364364
nlimbs::Int # number of limbs in generated BigInt's (z ∈ [0, m])
365365
nlimbsmax::Int # max number of limbs for z+a
366-
mask::Limb # applied to the highest limb
366+
highsp::SP # sampler for the highest limb of z
367367
end
368368

369-
function SamplerBigInt(r::AbstractUnitRange{BigInt})
369+
function SamplerBigInt(::Type{RNG}, r::AbstractUnitRange{BigInt}, N::Repetition=Val(Inf)
370+
) where {RNG<:AbstractRNG}
370371
m = last(r) - first(r)
371-
m < 0 && throw(ArgumentError("range must be non-empty"))
372-
nd = ndigits(m, base=2)
373-
nlimbs, highbits = divrem(nd, 8*sizeof(Limb))
374-
highbits > 0 && (nlimbs += 1)
375-
mask = highbits == 0 ? ~zero(Limb) : one(Limb)<<highbits - one(Limb)
372+
m.size < 0 && throw(ArgumentError("range must be non-empty"))
373+
nlimbs = Int(m.size)
374+
hm = nlimbs == 0 ? Limb(0) : GC.@preserve m unsafe_load(m.d, nlimbs)
375+
highsp = Sampler(RNG, Limb(0):hm, N)
376376
nlimbsmax = max(nlimbs, abs(last(r).size), abs(first(r).size))
377-
return SamplerBigInt(first(r), m, nlimbs, nlimbsmax, mask)
377+
return SamplerBigInt(first(r), m, nlimbs, nlimbsmax, highsp)
378378
end
379379

380-
Sampler(::Type{<:AbstractRNG}, r::AbstractUnitRange{BigInt}, ::Repetition) = SamplerBigInt(r)
380+
Sampler(::Type{RNG}, r::AbstractUnitRange{BigInt}, N::Repetition) where {RNG<:AbstractRNG} =
381+
SamplerBigInt(RNG, r, N)
381382

382383
rand(rng::AbstractRNG, sp::SamplerBigInt) =
383384
rand!(rng, BigInt(nbits = sp.nlimbsmax*8*sizeof(Limb)), sp)
384385

385386
function rand!(rng::AbstractRNG, x::BigInt, sp::SamplerBigInt)
387+
nlimbs = sp.nlimbs
388+
nlimbs == 0 && return MPZ.set!(x, sp.a)
386389
MPZ.realloc2!(x, sp.nlimbsmax*8*sizeof(Limb))
390+
@assert x.alloc >= nlimbs
391+
# we randomize x ∈ [0, m] with rejection sampling:
392+
# 1. the first nlimbs-1 limbs of x are uniformly randomized
393+
# 2. the high limb hx of x is sampled from 0:hm where hm is the
394+
# high limb of m
395+
# We repeat 1. and 2. until x <= m
396+
hm = GC.@preserve sp unsafe_load(sp.m.d, nlimbs)
387397
GC.@preserve x begin
388-
limbs = UnsafeView(x.d, sp.nlimbs)
398+
limbs = UnsafeView(x.d, nlimbs-1)
389399
while true
390400
rand!(rng, limbs)
391-
limbs[end] &= sp.mask
392-
MPZ.mpn_cmp(x, sp.m, sp.nlimbs) <= 0 && break
401+
hx = limbs[nlimbs] = rand(rng, sp.highsp)
402+
hx < hm && break # avoid calling mpn_cmp most of the time
403+
MPZ.mpn_cmp(x, sp.m, nlimbs) <= 0 && break
393404
end
394405
# adjust x.size (normally done by mpz_limbs_finish, in GMP version >= 6)
395-
x.size = sp.nlimbs
396-
while x.size > 0
397-
limbs[x.size] != 0 && break
398-
x.size -= 1
406+
while nlimbs > 0
407+
limbs[nlimbs] != 0 && break
408+
nlimbs -= 1
399409
end
410+
x.size = nlimbs
400411
end
401412
MPZ.add!(x, sp.a)
402413
end

stdlib/Random/test/runtests.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -909,15 +909,26 @@ end
909909
@test m == MersenneTwister(0, (0, 2256, 1254, 1, 0, 1))
910910
end
911911

912-
@testset "rand! for BigInt/BigFloat" begin
912+
@testset "rand[!] for BigInt/BigFloat" begin
913913
rng = MersenneTwister()
914-
s = Random.SamplerBigInt(1:big(9))
914+
s = Random.SamplerBigInt(MersenneTwister, 1:big(9))
915915
x = rand(s)
916916
@test x isa BigInt
917917
y = rand!(rng, x, s)
918918
@test y === x
919919
@test x in 1:9
920920

921+
for t = BigInt[0, 10, big(2)^100]
922+
s = Random.Sampler(rng, t:t) # s.nlimbs == 0
923+
@test rand(rng, s) == t
924+
@test x === rand!(rng, x, s) == t
925+
926+
s = Random.Sampler(rng, big(-1):t) # s.nlimbs != 0
927+
@test rand(rng, s) -1:t
928+
@test x === rand!(rng, x, s) -1:t
929+
930+
end
931+
921932
s = Random.Sampler(MersenneTwister, Random.CloseOpen01(BigFloat))
922933
x = rand(s)
923934
@test x isa BigFloat

0 commit comments

Comments
 (0)