Skip to content

Commit f1baa90

Browse files
authored
Merge pull request #24885 from JuliaLang/rf/rand/ptr
fix some FIXMEs in random, and take 2 to #9174
2 parents 24fce75 + 3c1d672 commit f1baa90

File tree

4 files changed

+124
-69
lines changed

4 files changed

+124
-69
lines changed

base/random/RNGs.jl

Lines changed: 107 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ mt_setempty!(r::MersenneTwister) = r.idx = MTCacheLength
143143
mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idx+=1]
144144

145145
function gen_rand(r::MersenneTwister)
146-
Base.@gc_preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals))
146+
@gc_preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals))
147147
mt_setfull!(r)
148148
end
149149

@@ -264,12 +264,14 @@ rand(r::MersenneTwister, ::SamplerType{Int128}) = rand(r, UInt128) % Int128
264264

265265
#### arrays of floats
266266

267-
function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float64},
268-
region::AbstractUnitRange, I::FloatInterval_64)
269-
# region should be a subset of linearindices(A)
267+
##### AbstractArray
268+
269+
function rand!(r::MersenneTwister, A::AbstractArray{Float64},
270+
I::SamplerTrivial{<:FloatInterval_64})
271+
region = linearindices(A)
270272
# what follows is equivalent to this simple loop but more efficient:
271273
# for i=region
272-
# @inbounds A[i] = rand(r, I)
274+
# @inbounds A[i] = rand(r, I[])
273275
# end
274276
m = Base.checked_sub(first(region), 1)
275277
n = last(region)
@@ -281,53 +283,101 @@ function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float6
281283
end
282284
m2 = min(n, m+s)
283285
for i=m+1:m2
284-
@inbounds A[i] = rand_inbounds(r, I)
286+
@inbounds A[i] = rand_inbounds(r, I[])
285287
end
286288
m = m2
287289
end
288290
A
289291
end
290292

291-
rand!(r::MersenneTwister, A::AbstractArray{Float64}, I::SamplerTrivial{<:FloatInterval_64}) =
292-
rand_AbstractArray_Float64!(r, A, linearindices(A), I[])
293+
294+
##### Array : internal functions
295+
296+
# internal array-like type to circumevent the lack of flexibility with reinterpret
297+
struct UnsafeView{T} <: DenseArray{T,1}
298+
ptr::Ptr{T}
299+
len::Int
300+
end
301+
302+
Base.length(a::UnsafeView) = a.len
303+
Base.getindex(a::UnsafeView, i::Int) = unsafe_load(a.ptr, i)
304+
Base.setindex!(a::UnsafeView, x, i::Int) = unsafe_store!(a.ptr, x, i)
305+
Base.pointer(a::UnsafeView) = a.ptr
306+
Base.size(a::UnsafeView) = (a.len,)
307+
308+
# this is essentially equivalent to rand!(r, ::AbstractArray{Float64}, I) above, but due to
309+
# optimizations which can't be done currently when working with pointers, we have to re-order
310+
# manually the computation flow to get the performance
311+
# (see https://discourse.julialang.org/t/unsafe-store-sometimes-slower-than-arrays-setindex)
312+
function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInterval_64)
313+
n = length(A)
314+
@assert n <= dsfmt_get_min_array_size()+1 # == 383
315+
mt_avail(r) == 0 && gen_rand(r)
316+
# from now on, at most one call to gen_rand(r) will be necessary
317+
m = min(n, mt_avail(r))
318+
@gc_preserve r unsafe_copy!(A.ptr, pointer(r.vals, r.idx+1), m)
319+
if m == n
320+
r.idx += m
321+
else # m < n
322+
gen_rand(r)
323+
@gc_preserve r unsafe_copy!(A.ptr+m*sizeof(Float64), pointer(r.vals), n-m)
324+
r.idx = n-m
325+
end
326+
if I isa CloseOpen
327+
for i=1:n
328+
A[i] -= 1.0
329+
end
330+
end
331+
A
332+
end
333+
293334

294335
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) =
295336
dsfmt_fill_array_close_open!(s, A, n)
296337

297338
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::Close1Open2_64) =
298339
dsfmt_fill_array_close1_open2!(s, A, n)
299340

300-
function _rand!(r::MersenneTwister, A::Array{Float64}, n::Int,
301-
I::FloatInterval_64)
341+
342+
function rand!(r::MersenneTwister, A::UnsafeView{Float64},
343+
I::SamplerTrivial{<:FloatInterval_64})
302344
# depending on the alignment of A, the data written by fill_array! may have
303345
# to be left-shifted by up to 15 bytes (cf. unsafe_copy! below) for
304346
# reproducibility purposes;
305347
# so, even for well aligned arrays, fill_array! is used to generate only
306348
# the n-2 first values (or n-3 if n is odd), and the remaining values are
307349
# generated by the scalar version of rand
308-
n > length(A) && throw(BoundsError(A, n))
350+
n = length(A)
309351
n2 = (n-2) ÷ 2 * 2
310-
if n2 < dsfmt_get_min_array_size()
311-
rand_AbstractArray_Float64!(r, A, 1:n, I)
352+
n2 < dsfmt_get_min_array_size() && return _rand_max383!(r, A, I[])
353+
354+
pA = A.ptr
355+
align = Csize_t(pA) % 16
356+
if align > 0
357+
pA2 = pA + 16 - align
358+
fill_array!(r.state, pA2, n2, I[]) # generate the data in-place, but shifted
359+
unsafe_copy!(pA, pA2, n2) # move the data to the beginning of the array
312360
else
313-
pA = pointer(A)
314-
align = Csize_t(pA) % 16
315-
Base.@gc_preserve A if align > 0
316-
pA2 = pA + 16 - align
317-
fill_array!(r.state, pA2, n2, I) # generate the data in-place, but shifted
318-
unsafe_copy!(pA, pA2, n2) # move the data to the beginning of the array
319-
else
320-
fill_array!(r.state, pA, n2, I)
321-
end
322-
for i=n2+1:n
323-
@inbounds A[i] = rand(r, I)
324-
end
361+
fill_array!(r.state, pA, n2, I[])
325362
end
363+
for i=n2+1:n
364+
A[i] = rand(r, I[])
365+
end
366+
A
367+
end
368+
369+
# fills up A reinterpreted as an array of Float64 with n64 values
370+
function _rand!(r::MersenneTwister, A::Array{T}, n64::Int, I::FloatInterval_64) where T
371+
# n64 is the length in terms of `Float64` of the target
372+
@assert sizeof(Float64)*n64 <= sizeof(T)*length(A) && isbits(T)
373+
@gc_preserve A rand!(r, UnsafeView{Float64}(pointer(A), n64), SamplerTrivial(I))
326374
A
327375
end
328376

329-
rand!(r::MersenneTwister, A::Array{Float64}, sp::SamplerTrivial{<:FloatInterval_64}) =
330-
_rand!(r, A, length(A), sp[])
377+
##### Array: Float64, Float16, Float32
378+
379+
rand!(r::MersenneTwister, A::Array{Float64}, I::SamplerTrivial{<:FloatInterval_64}) =
380+
_rand!(r, A, length(A), I[])
331381

332382
mask128(u::UInt128, ::Type{Float16}) =
333383
(u & 0x03ff03ff03ff03ff03ff03ff03ff03ff) | 0x3c003c003c003c003c003c003c003c00
@@ -339,27 +389,27 @@ for T in (Float16, Float32)
339389
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerTrivial{Close1Open2{$T}})
340390
n = length(A)
341391
n128 = n * sizeof($T) ÷ 16
342-
Base.@gc_preserve A _rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
343-
2*n128, Close1Open2())
344-
# FIXME: This code is completely invalid!!!
345-
A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)
346-
@inbounds for i in 1:n128
347-
u = A128[i]
348-
u ⊻= u << 26
349-
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
350-
# the bit xor, are:
351-
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
352-
# the bits needing to be random are
353-
# [1:10, 17:26, 33:42, 49:58] (for Float16)
354-
# [1:23, 33:55] (for Float32)
355-
# this is obviously satisfied on the 32 low bits side, and on the high side,
356-
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
357-
# (which are discarded on the low side)
358-
# this is similar for the 64 high bits of u
359-
A128[i] = mask128(u, $T)
392+
_rand!(r, A, 2*n128, Close1Open2())
393+
@gc_preserve A begin
394+
A128 = UnsafeView{UInt128}(pointer(A), n128)
395+
for i in 1:n128
396+
u = A128[i]
397+
u ⊻= u << 26
398+
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
399+
# the bit xor, are:
400+
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
401+
# the bits needing to be random are
402+
# [1:10, 17:26, 33:42, 49:58] (for Float16)
403+
# [1:23, 33:55] (for Float32)
404+
# this is obviously satisfied on the 32 low bits side, and on the high side,
405+
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
406+
# (which are discarded on the low side)
407+
# this is similar for the 64 high bits of u
408+
A128[i] = mask128(u, $T)
409+
end
360410
end
361411
for i in 16*n128÷sizeof($T)+1:n
362-
@inbounds A[i] = rand(r, $T) + oneunit($T)
412+
@inbounds A[i] = rand(r, $T) + one($T)
363413
end
364414
A
365415
end
@@ -376,16 +426,14 @@ end
376426

377427
#### arrays of integers
378428

379-
function rand!(r::MersenneTwister, A::Array{UInt128}, ::SamplerType{UInt128})
429+
function rand!(r::MersenneTwister, A::UnsafeView{UInt128}, ::SamplerType{UInt128})
380430
n::Int=length(A)
381-
# FIXME: This code is completely invalid!!!
382-
Af = unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2n)
383431
i = n
384432
while true
385-
_rand!(r, Af, 2i, Close1Open2())
433+
rand!(r, UnsafeView{Float64}(A.ptr, 2i), Close1Open2())
386434
n < 5 && break
387435
i = 0
388-
@inbounds while n-i >= 5
436+
while n-i >= 5
389437
u = A[i+=1]
390438
A[n] ⊻= u << 48
391439
A[n-=1] ⊻= u << 36
@@ -397,19 +445,22 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, ::SamplerType{UInt128})
397445
if n > 0
398446
u = rand(r, UInt2x52Raw())
399447
for i = 1:n
400-
@inbounds A[i] ⊻= u << (12*i)
448+
A[i] ⊻= u << (12*i)
401449
end
402450
end
403451
A
404452
end
405453

406454
for T in BitInteger_types
407-
T === UInt128 && continue
408-
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerType{$T})
455+
@eval rand!(r::MersenneTwister, A::Array{$T}, sp::SamplerType{$T}) =
456+
(@gc_preserve A rand!(r, UnsafeView(pointer(A), length(A)), sp); A)
457+
458+
T == UInt128 && continue
459+
460+
@eval function rand!(r::MersenneTwister, A::UnsafeView{$T}, ::SamplerType{$T})
409461
n = length(A)
410462
n128 = n * sizeof($T) ÷ 16
411-
# FIXME: This code is completely invalid!!!
412-
rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128))
463+
rand!(r, UnsafeView{UInt128}(pointer(A), n128))
413464
for i = 16*n128÷sizeof($T)+1:n
414465
@inbounds A[i] = rand(r, $T)
415466
end

base/random/generation.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -278,17 +278,19 @@ end
278278

279279
function rand(rng::AbstractRNG, sp::SamplerBigInt)
280280
x = MPZ.realloc2(sp.nlimbsmax*8*sizeof(Limb))
281-
limbs = unsafe_wrap(Array, x.d, sp.nlimbs)
282-
while true
283-
rand!(rng, limbs)
284-
@inbounds limbs[end] &= sp.mask
285-
MPZ.mpn_cmp(x, sp.m, sp.nlimbs) <= 0 && break
286-
end
287-
# adjust x.size (normally done by mpz_limbs_finish, in GMP version >= 6)
288-
x.size = sp.nlimbs
289-
while x.size > 0
290-
@inbounds limbs[x.size] != 0 && break
291-
x.size -= 1
281+
@gc_preserve x begin
282+
limbs = UnsafeView(x.d, sp.nlimbs)
283+
while true
284+
rand!(rng, limbs)
285+
limbs[end] &= sp.mask
286+
MPZ.mpn_cmp(x, sp.m, sp.nlimbs) <= 0 && break
287+
end
288+
# adjust x.size (normally done by mpz_limbs_finish, in GMP version >= 6)
289+
x.size = sp.nlimbs
290+
while x.size > 0
291+
limbs[x.size] != 0 && break
292+
x.size -= 1
293+
end
292294
end
293295
MPZ.add!(x, sp.a)
294296
end

base/random/random.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ module Random
44

55
using Base.dSFMT
66
using Base.GMP: Limb, MPZ
7-
using Base: BitInteger, BitInteger_types, BitUnsigned
7+
using Base: BitInteger, BitInteger_types, BitUnsigned, @gc_preserve
8+
89
import Base: copymutable, copy, copy!, ==, hash
910

1011
export srand,

test/random.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ let mt = MersenneTwister(0)
286286
@test rand!(mt, AF64)[end] == 0.957735065345398
287287
@test rand!(mt, AF64)[end] == 0.6492481059865669
288288
resize!(AF64, 2*length(mt.vals))
289-
@test Base.Random.rand_AbstractArray_Float64!(mt, AF64, linearindices(AF64), Base.Random.CloseOpen())[end] == 0.432757268470779
289+
@test invoke(rand!, Tuple{MersenneTwister,AbstractArray{Float64},Base.Random.SamplerTrivial{Base.Random.CloseOpen_64}},
290+
mt, AF64, Base.Random.SamplerTrivial(Base.Random.CloseOpen()))[end] == 0.432757268470779
290291
end
291292

292293
# Issue #9037

0 commit comments

Comments
 (0)