Skip to content

Commit 8a751b2

Browse files
committed
fix some FIXMEs in random, and take 2 to #9174
1 parent 295b098 commit 8a751b2

File tree

3 files changed

+103
-57
lines changed

3 files changed

+103
-57
lines changed

base/random/RNGs.jl

Lines changed: 99 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ mt_setempty!(r::MersenneTwister) = r.idx = MTCacheLength
143143
mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idx+=1]
144144

145145
function gen_rand(r::MersenneTwister)
146-
Base.@gc_preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals))
146+
@gc_preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals))
147147
mt_setfull!(r)
148148
end
149149

@@ -264,12 +264,14 @@ rand(r::MersenneTwister, ::SamplerType{Int128}) = rand(r, UInt128) % Int128
264264

265265
#### arrays of floats
266266

267-
function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float64},
268-
region::AbstractUnitRange, I::FloatInterval_64)
269-
# region should be a subset of linearindices(A)
267+
##### AbstractArray
268+
269+
function rand!(r::MersenneTwister, A::AbstractArray{Float64},
270+
I::SamplerTrivial{<:FloatInterval_64})
271+
region = linearindices(A)
270272
# what follows is equivalent to this simple loop but more efficient:
271273
# for i=region
272-
# @inbounds A[i] = rand(r, I)
274+
# @inbounds A[i] = rand(r, I[])
273275
# end
274276
m = Base.checked_sub(first(region), 1)
275277
n = last(region)
@@ -281,53 +283,99 @@ function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float6
281283
end
282284
m2 = min(n, m+s)
283285
for i=m+1:m2
284-
@inbounds A[i] = rand_inbounds(r, I)
286+
@inbounds A[i] = rand_inbounds(r, I[])
285287
end
286288
m = m2
287289
end
288290
A
289291
end
290292

291-
rand!(r::MersenneTwister, A::AbstractArray{Float64}, I::SamplerTrivial{<:FloatInterval_64}) =
292-
rand_AbstractArray_Float64!(r, A, linearindices(A), I[])
293+
294+
##### Array : internal functions
295+
296+
# internal array-like type to circumevent the lack of flexibility with reinterpret
297+
struct UnsafeView{T} <: DenseArray{T,1}
298+
ptr::Ptr{T}
299+
len::Int
300+
end
301+
302+
Base.length(a::UnsafeView) = a.len
303+
Base.getindex(a::UnsafeView, i::Int) = unsafe_load(a.ptr, i)
304+
Base.setindex!(a::UnsafeView, x, i::Int) = unsafe_store!(a.ptr, x, i)
305+
306+
# this is essentially equivalent to rand!(r, ::AbstractArray{Float64}, I) above, but due to
307+
# optimizations which can't be done currently when working with pointers, we have to re-order
308+
# manually the computation flow to get the performance
309+
# (see https://discourse.julialang.org/t/unsafe-store-sometimes-slower-than-arrays-setindex)
310+
function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInterval_64)
311+
n = length(A)
312+
@assert n <= dsfmt_get_min_array_size()+1 # == 383
313+
mt_avail(r) == 0 && gen_rand(r)
314+
# from now on, at most one call to gen_rand(r) will be necessary
315+
m = min(n, mt_avail(r))
316+
@gc_preserve r unsafe_copy!(A.ptr, pointer(r.vals, r.idx+1), m)
317+
if m == n
318+
r.idx += m
319+
else # m < n
320+
gen_rand(r)
321+
@gc_preserve r unsafe_copy!(A.ptr+m*sizeof(Float64), pointer(r.vals), n-m)
322+
r.idx = n-m
323+
end
324+
if I isa CloseOpen
325+
for i=1:n
326+
A[i] -= 1.0
327+
end
328+
end
329+
A
330+
end
331+
293332

294333
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) =
295334
dsfmt_fill_array_close_open!(s, A, n)
296335

297336
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::Close1Open2_64) =
298337
dsfmt_fill_array_close1_open2!(s, A, n)
299338

300-
function _rand!(r::MersenneTwister, A::Array{Float64}, n::Int,
301-
I::FloatInterval_64)
339+
340+
function rand!(r::MersenneTwister, A::UnsafeView{Float64},
341+
I::SamplerTrivial{<:FloatInterval_64})
302342
# depending on the alignment of A, the data written by fill_array! may have
303343
# to be left-shifted by up to 15 bytes (cf. unsafe_copy! below) for
304344
# reproducibility purposes;
305345
# so, even for well aligned arrays, fill_array! is used to generate only
306346
# the n-2 first values (or n-3 if n is odd), and the remaining values are
307347
# generated by the scalar version of rand
308-
n > length(A) && throw(BoundsError(A, n))
348+
n = length(A)
309349
n2 = (n-2) ÷ 2 * 2
310-
if n2 < dsfmt_get_min_array_size()
311-
rand_AbstractArray_Float64!(r, A, 1:n, I)
350+
n2 < dsfmt_get_min_array_size() && return _rand_max383!(r, A, I[])
351+
352+
pA = A.ptr
353+
align = Csize_t(pA) % 16
354+
if align > 0
355+
pA2 = pA + 16 - align
356+
fill_array!(r.state, pA2, n2, I[]) # generate the data in-place, but shifted
357+
unsafe_copy!(pA, pA2, n2) # move the data to the beginning of the array
312358
else
313-
pA = pointer(A)
314-
align = Csize_t(pA) % 16
315-
Base.@gc_preserve A if align > 0
316-
pA2 = pA + 16 - align
317-
fill_array!(r.state, pA2, n2, I) # generate the data in-place, but shifted
318-
unsafe_copy!(pA, pA2, n2) # move the data to the beginning of the array
319-
else
320-
fill_array!(r.state, pA, n2, I)
321-
end
322-
for i=n2+1:n
323-
@inbounds A[i] = rand(r, I)
324-
end
359+
fill_array!(r.state, pA, n2, I[])
360+
end
361+
for i=n2+1:n
362+
A[i] = rand(r, I[])
325363
end
326364
A
327365
end
328366

329-
rand!(r::MersenneTwister, A::Array{Float64}, sp::SamplerTrivial{<:FloatInterval_64}) =
330-
_rand!(r, A, length(A), sp[])
367+
# fills up A reinterpreted as an array of Float64 with n64 values
368+
function _rand!(r::MersenneTwister, A::Array{T}, n64::Int, I::FloatInterval_64) where T
369+
# n64 is the length in terms of `Float64` of the target
370+
@assert sizeof(Float64)*n64 <= sizeof(T)*length(A) && isbits(T)
371+
@gc_preserve A rand!(r, UnsafeView{Float64}(pointer(A), n64), SamplerTrivial(I))
372+
A
373+
end
374+
375+
##### Array: Float64, Float16, Float32
376+
377+
rand!(r::MersenneTwister, A::Array{Float64}, I::SamplerTrivial{<:FloatInterval_64}) =
378+
_rand!(r, A, length(A), I[])
331379

332380
mask128(u::UInt128, ::Type{Float16}) =
333381
(u & 0x03ff03ff03ff03ff03ff03ff03ff03ff) | 0x3c003c003c003c003c003c003c003c00
@@ -339,27 +387,27 @@ for T in (Float16, Float32)
339387
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerTrivial{Close1Open2{$T}})
340388
n = length(A)
341389
n128 = n * sizeof($T) ÷ 16
342-
Base.@gc_preserve A _rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
343-
2*n128, Close1Open2())
344-
# FIXME: This code is completely invalid!!!
345-
A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)
346-
@inbounds for i in 1:n128
347-
u = A128[i]
348-
u ⊻= u << 26
349-
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
350-
# the bit xor, are:
351-
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
352-
# the bits needing to be random are
353-
# [1:10, 17:26, 33:42, 49:58] (for Float16)
354-
# [1:23, 33:55] (for Float32)
355-
# this is obviously satisfied on the 32 low bits side, and on the high side,
356-
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
357-
# (which are discarded on the low side)
358-
# this is similar for the 64 high bits of u
359-
A128[i] = mask128(u, $T)
390+
_rand!(r, A, 2*n128, Close1Open2())
391+
@gc_preserve A begin
392+
A128 = UnsafeView{UInt128}(pointer(A), n128)
393+
for i in 1:n128
394+
u = A128[i]
395+
u ⊻= u << 26
396+
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
397+
# the bit xor, are:
398+
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
399+
# the bits needing to be random are
400+
# [1:10, 17:26, 33:42, 49:58] (for Float16)
401+
# [1:23, 33:55] (for Float32)
402+
# this is obviously satisfied on the 32 low bits side, and on the high side,
403+
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
404+
# (which are discarded on the low side)
405+
# this is similar for the 64 high bits of u
406+
A128[i] = mask128(u, $T)
407+
end
360408
end
361409
for i in 16*n128÷sizeof($T)+1:n
362-
@inbounds A[i] = rand(r, $T) + oneunit($T)
410+
@inbounds A[i] = rand(r, $T) + one($T)
363411
end
364412
A
365413
end
@@ -376,16 +424,14 @@ end
376424

377425
#### arrays of integers
378426

379-
function rand!(r::MersenneTwister, A::Array{UInt128}, ::SamplerType{UInt128})
427+
function _rand!(r::MersenneTwister, A::UnsafeView{UInt128})
380428
n::Int=length(A)
381-
# FIXME: This code is completely invalid!!!
382-
Af = unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2n)
383429
i = n
384430
while true
385-
_rand!(r, Af, 2i, Close1Open2())
431+
rand!(r, UnsafeView{Float64}(A.ptr, 2i), Close1Open2())
386432
n < 5 && break
387433
i = 0
388-
@inbounds while n-i >= 5
434+
while n-i >= 5
389435
u = A[i+=1]
390436
A[n] ⊻= u << 48
391437
A[n-=1] ⊻= u << 36
@@ -397,19 +443,17 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, ::SamplerType{UInt128})
397443
if n > 0
398444
u = rand(r, UInt2x52Raw())
399445
for i = 1:n
400-
@inbounds A[i] ⊻= u << (12*i)
446+
A[i] ⊻= u << (12*i)
401447
end
402448
end
403449
A
404450
end
405451

406452
for T in BitInteger_types
407-
T === UInt128 && continue
408453
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerType{$T})
409454
n = length(A)
410455
n128 = n * sizeof($T) ÷ 16
411-
# FIXME: This code is completely invalid!!!
412-
rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128))
456+
@gc_preserve A _rand!(r, UnsafeView{UInt128}(pointer(A), n128))
413457
for i = 16*n128÷sizeof($T)+1:n
414458
@inbounds A[i] = rand(r, $T)
415459
end

base/random/random.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ module Random
44

55
using Base.dSFMT
66
using Base.GMP: Limb, MPZ
7-
using Base: BitInteger, BitInteger_types, BitUnsigned
7+
using Base: BitInteger, BitInteger_types, BitUnsigned, @gc_preserve
8+
89
import Base: copymutable, copy, copy!, ==, hash
910

1011
export srand,

test/random.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ let mt = MersenneTwister(0)
286286
@test rand!(mt, AF64)[end] == 0.957735065345398
287287
@test rand!(mt, AF64)[end] == 0.6492481059865669
288288
resize!(AF64, 2*length(mt.vals))
289-
@test Base.Random.rand_AbstractArray_Float64!(mt, AF64, linearindices(AF64), Base.Random.CloseOpen())[end] == 0.432757268470779
289+
@test invoke(rand!, Tuple{MersenneTwister,AbstractArray{Float64},Base.Random.SamplerTrivial{Base.Random.CloseOpen_64}},
290+
mt, AF64, Base.Random.SamplerTrivial(Base.Random.CloseOpen()))[end] == 0.432757268470779
290291
end
291292

292293
# Issue #9037

0 commit comments

Comments
 (0)