Skip to content

Commit fe143f7

Browse files
committed
Don't call rand without RNG, and add Float16 support.
1 parent cfe1b6d commit fe143f7

File tree

4 files changed

+37
-44
lines changed

4 files changed

+37
-44
lines changed

src/host/random.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ LCGStep(z::Unsigned, A::Unsigned, C::Unsigned) = A * z + C
1616

1717
make_rand_num(::Type{Float64}, tmp) = 2.3283064365387e-10 * Float64(tmp)
1818
make_rand_num(::Type{Float32}, tmp) = 2.3283064f-10 * Float32(tmp)
19+
# NOTE: the rng state is often not representable as Float16, so perform this in Float32.
20+
make_rand_num(::Type{Float16}, tmp) = Float16(2.3283064f-10 * Float32(tmp))
1921

2022
function next_rand(state::NTuple{4, T}) where {T <: Unsigned}
2123
state = (

test/jlarray.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -322,13 +322,21 @@ Base.copyto!(dest::DenseJLArray{T}, source::DenseJLArray{T}) where {T} =
322322
copyto!(dest, 1, source, 1, length(source))
323323

324324

325-
## Random
325+
## random number generation
326326

327327
using Random
328328

329-
# JLArray only supports generating random numbers with the GPUArrays RNG
330-
Random.rand!(A::AnyJLArray) = Random.rand!(GPUArrays.default_rng(JLArray), A)
331-
Random.randn!(A::AnyJLArray) = Random.randn!(GPUArrays.default_rng(JLArray), A)
329+
const GLOBAL_RNG = Ref{Union{Nothing,GPUArrays.RNG}}(nothing)
330+
function GPUArrays.default_rng(::Type{<:JLArray})
331+
if GLOBAL_RNG[] === nothing
332+
N = MAXTHREADS
333+
state = JLArray{NTuple{4, UInt32}}(undef, N)
334+
rng = GPUArrays.RNG(state)
335+
Random.seed!(rng)
336+
GLOBAL_RNG[] = rng
337+
end
338+
GLOBAL_RNG[]
339+
end
332340

333341

334342
## GPUArrays interfaces
@@ -346,16 +354,4 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
346354
@allowscalar Base.reducedim!(op, R.data, map(f, A))
347355
end
348356

349-
const GLOBAL_RNG = Ref{Union{Nothing,GPUArrays.RNG}}(nothing)
350-
function GPUArrays.default_rng(::Type{<:JLArray})
351-
if GLOBAL_RNG[] === nothing
352-
N = MAXTHREADS
353-
state = JLArray{NTuple{4, UInt32}}(undef, N)
354-
rng = GPUArrays.RNG(state)
355-
Random.seed!(rng)
356-
GLOBAL_RNG[] = rng
357-
end
358-
GLOBAL_RNG[]
359-
end
360-
361357
end

test/testsuite/linalg.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,18 @@ end
4747
gpu_a = AT{Float32}(undef, 128, 128)
4848
gpu_b = AT{Float32}(undef, 128, 128)
4949

50-
rand!(gpu_a)
50+
copyto!(gpu_a, rand(Float32, (128,128)))
5151
copyto!(cpu_a, TR(gpu_a))
5252
@test cpu_a == Array(collect(TR(gpu_a)))
5353

54-
rand!(gpu_a)
54+
copyto!(gpu_a, rand(Float32, (128,128)))
5555
gpu_c = copyto!(gpu_b, TR(gpu_a))
5656
@test all(Array(gpu_b) .== TR(Array(gpu_a)))
5757
@test gpu_c isa AT
5858
end
5959

6060
for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
61-
gpu_a = AT{Float32}(undef, 128, 128) |> rand! |> TR
61+
gpu_a = AT(rand(Float32, (128,128))) |> TR
6262
gpu_b = AT{Float32}(undef, 128, 128) |> TR
6363

6464
gpu_c = copyto!(gpu_b, gpu_a)
@@ -72,10 +72,8 @@ end
7272
@testsuite "linalg/diagonal" AT->begin
7373
@testset "Array + Diagonal" begin
7474
n = 128
75-
A = AT{Float32}(undef, n, n)
76-
d = AT{Float32}(undef, n)
77-
rand!(A)
78-
rand!(d)
75+
A = AT(rand(Float32, (n,n)))
76+
d = AT(rand(Float32, n))
7977
D = Diagonal(d)
8078
B = A + D
8179
@test collect(B) collect(A) + collect(D)

test/testsuite/random.jl

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
@testsuite "random" AT->begin
2+
rng = if AT <: AbstractGPUArray
3+
GPUArrays.default_rng(AT)
4+
else
5+
Random.default_rng()
6+
end
7+
28
@testset "rand" begin # uniform
3-
for T in (Int8, Float32, Float64, Int64, Int32,
4-
Complex{Float32}, Complex{Float64},
5-
Complex{Int64}, Complex{Int32}), d in (10, (10,10))
9+
for T in (Int16, Int32, Int64,
10+
Float16, Float32, Float64,
11+
Complex{Float16}, Complex{Float32}, Complex{Float64},
12+
Complex{Int32}, Complex{Int64}), d in (10, (10,10))
613
A = AT{T}(undef, d)
714
B = copy(A)
8-
rand!(A)
9-
rand!(B)
15+
rand!(rng, A)
16+
rand!(rng, B)
1017
@test Array(A) != Array(B)
1118

12-
rng = if AT <: AbstractGPUArray
13-
GPUArrays.default_rng(AT)
14-
else
15-
Random.default_rng()
16-
end
1719
Random.seed!(rng)
1820
Random.seed!(rng, 1)
1921
rand!(rng, A)
@@ -24,26 +26,21 @@
2426

2527
A = AT{Bool}(undef, 1024)
2628
fill!(A, false)
27-
rand!(A)
29+
rand!(rng, A)
2830
@test true in Array(A)
2931
fill!(A, true)
30-
rand!(A)
32+
rand!(rng, A)
3133
@test false in Array(A)
3234
end
3335

34-
@testset "randn" begin # uniform
35-
for T in (Float32, Float64), d in (2, (2,2))
36+
@testset "randn" begin # normally-distributed
37+
for T in (Float16, Float32, Float64), d in (2, (2,2))
3638
A = AT{T}(undef, d)
3739
B = copy(A)
38-
randn!(A)
39-
randn!(B)
40+
randn!(rng, A)
41+
randn!(rng, B)
4042
@test !any(Array(A) .== Array(B))
4143

42-
rng = if AT <: AbstractGPUArray
43-
GPUArrays.default_rng(AT)
44-
else
45-
Random.default_rng()
46-
end
4744
Random.seed!(rng)
4845
Random.seed!(rng, 1)
4946
randn!(rng, A)

0 commit comments

Comments
 (0)