Skip to content

Commit 88a8636

Browse files
evelyne-ringootvchuravymaleadt
authored
Aligning broadcasting errors between GPU and CPU (#471)
Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com> Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent cd237a4 commit 88a8636

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

src/host/indexing.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ function _setindex!(dest::AbstractGPUArray, src, Is...)
6868
idims = length.(Is)
6969
len = prod(idims)
7070
len==0 && return dest
71+
if length(src) != len
72+
if length(src) == 1
73+
throw(ArgumentError("indexed assignment with a single value to possibly many locations is not supported; perhaps use broadcasting `.=` instead?"))
74+
else
75+
throw(DimensionMismatch("dimensions must match: a has "*string(length(src))*" elements, b has "*string(len)))
76+
end
77+
end
7178

7279
AT = typeof(dest).name.wrapper
7380
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...

test/testsuite/indexing.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ end
118118
@testset "JuliaGPU/CUDA.jl#461: sliced setindex" begin
119119
@test compare((X,Y)->(X[1,:] = Y), AT, zeros(Float32, 2,2), ones(Float32, 2))
120120
end
121+
122+
@testset "Broadcasting exceptions" for T in eltypes
123+
x = AT(zeros(T, (10, 10, 10, 10)))
124+
@test_throws ArgumentError x[1, :, :, :] = 0
125+
y = AT(rand(T, (5, 5, 5, 5)))
126+
@test_throws DimensionMismatch x[1:9,1:9,:,:] = y
127+
end
128+
121129
end
122130

123131
@testsuite "indexing find" (AT, eltypes)->begin

0 commit comments

Comments
 (0)