Skip to content

Commit d187fd9

Browse files
Fix return value of randn! on empty inputs (#528)
1 parent 4623226 commit d187fd9

File tree

5 files changed

+24
-2
lines changed

5 files changed

+24
-2
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ Manifest.toml
99
.*.swp
1010
.*.swo
1111
*~
12+
13+
# MacOS generated files
14+
*.DS_Store

src/host/construction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractGPUArray} = a isa T
1010
## convenience constructors
1111

1212
function Base.fill!(A::AnyGPUArray{T}, x) where T
13-
length(A) == 0 && return A
13+
isempty(A) && return A
1414
gpu_call(A, convert(T, x)) do ctx, a, val
1515
idx = @linearidx(a)
1616
@inbounds a[idx] = val

src/host/random.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ function Random.seed!(rng::RNG, seed::Vector{UInt32})
8484
end
8585

8686
function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
87+
isempty(A) && return A
8788
gpu_call(A, rng.state) do ctx, a, randstates
8889
idx = linear_index(ctx)
8990
idx > length(a) && return
@@ -94,8 +95,8 @@ function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
9495
end
9596

9697
function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
98+
isempty(A) && return A
9799
threads = (length(A) - 1) ÷ 2 + 1
98-
length(A) == 0 && return
99100
gpu_call(A, rng.state; elements = threads) do ctx, a, randstates
100101
idx = 2*(linear_index(ctx) - 1) + 1
101102
U1 = gpu_rand(T, ctx, randstates)

test/testsuite/construction.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@
102102

103103
@testset "convenience" begin
104104
for T in eltypes
105+
A = AT(rand(T, 0))
106+
b = rand(T)
107+
fill!(A, b)
108+
@test A isa AT{T,1}
109+
@test Array(A) == fill(b, 0)
110+
105111
A = AT(rand(T, 3))
106112
b = rand(T)
107113
fill!(A, b)

test/testsuite/random.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@
3333
fill!(A, true)
3434
rand!(rng, A)
3535
@test false in Array(A)
36+
37+
# AT of length 0
38+
B = AT{Float32}(undef, 0)
39+
fill!(B, 1f0)
40+
rand!(rng, B)
41+
@test isempty(Array(B))
3642
end
3743

3844
@testset "randn" begin # normally-distributed
@@ -56,5 +62,11 @@
5662
randn!(cpu_rng, A)
5763
end
5864
end
65+
66+
# AT of length 0
67+
A = AT{Float32}(undef, 0)
68+
fill!(A, 1f0)
69+
randn!(rng, A)
70+
@test isempty(Array(A))
5971
end
6072
end

0 commit comments

Comments
 (0)