
Commit 4e2d74a

Fix GPU getindex (#53)
* delete ambiguous GPU getindex methods, and add tests
* restore special method for row indexing
* similar and copyto! to make convert(AbstractArray{Float32}, cx) work
* copyto! method using Adapt to move to the CPU
* fix & test invoke case
* v0.2.8
1 parent ebed6ab commit 4e2d74a
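For orientation, here is a minimal sketch of the behaviour this commit targets, mirroring the "gpu indexing" testset below. It is written against JLArrays so it runs without a GPU; the `jl` converter stands in for CUDA's `cu`, just as the test harness does when CUDA is not functional.

using OneHotArrays, JLArrays         # JLArray: a CPU-backed stand-in for a GPU array
using JLArrays: @allowscalar
JLArrays.allowscalar(false)          # error on scalar indexing, as on a real GPU

x  = onehotbatch([1, 2, 3, 2], 1:3)  # 3×4 OneHotMatrix
cx = jl(x)                           # move the underlying index vector to the "GPU"

cx[:, 1:2]                           # still a OneHotMatrix, backed by a JLArray
cx[2, :]                             # a row is a plain JLArray{Bool}, not one-hot
@allowscalar cx[2, 2]                # reading a single entry needs @allowscalar
collect(cx)                          # broken on v0.2.7, now matches collect(x)
convert(AbstractArray{Float32}, cx)  # now a JLArray{Float32}, via similar + copyto!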

5 files changed, +47 −8 lines changed
Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "OneHotArrays"
 uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
-version = "0.2.7"
+version = "0.2.8"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/array.jl

Lines changed: 14 additions & 6 deletions
@@ -64,13 +64,9 @@ Base.size(x::OneHotArray) = (x.nlabels, size(x.indices)...)

 function Base.getindex(x::OneHotArray{<:Any, N}, i::Int, I::Vararg{Int, N}) where N
   @boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...)))
-  return x.indices[I...] .== i
+  return x.indices[I...] == i
 end
-# the method above is faster on the CPU but will scalar index on the GPU
-# so we define the method below to pass the extra indices directly to GPU array
-function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:AbstractGPUArray},
-                       i::Int,
-                       I::Vararg{Any, N}) where N
+function Base.getindex(x::OneHotArray{<:Any, N}, i::Int, I::Vararg{Any, N}) where N
   @boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...)))
   return x.indices[I...] .== i
 end
@@ -80,6 +76,18 @@ end
 Base.getindex(x::OneHotArray, ::Colon) = BitVector(reshape(x, :))
 Base.getindex(x::OneHotArray{<:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x

+Base.similar(x::OneHotArray{<:Any,<:Any,<:Any,<:AbstractArray}, ::Type{T}, size::Base.Dims) where T =
+  similar(x.indices, T, size)
+
+function Base.copyto!(dst::AbstractArray{T,N}, src::OneHotArray{<:Any,<:Any,N,<:AbstractArray}) where {T,N}
+  size(dst) == size(src) || return invoke(copyto!, Tuple{typeof(dst), AbstractArray{Bool,N}}, dst, src)
+  dst .= reshape(src.indices, 1, size(src.indices)...) .== (1:src.nlabels)
+  return dst
+end
+function Base.copyto!(dst::Array{T,N}, src::OneHotArray{<:Any,<:Any,N,<:AnyGPUArray}) where {T,N}
+  copyto!(dst, adapt(Array, src))
+end
+
 function Base.showarg(io::IO, x::OneHotArray, toplevel)
   print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")
   Base.showarg(io, x.indices, false)
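Why `similar` plus `copyto!` is enough to make `convert(AbstractArray{Float32}, cx)` work: Base's generic `AbstractArray{T}` constructor allocates the destination with `similar` and then fills it with `copyto!`, so the two new methods above are exactly the hooks it needs. A rough sketch of the same steps done by hand, on the CPU for simplicity (variable names are illustrative, not from the package):

using OneHotArrays

x = onehotbatch([1, 2, 3, 2], 1:3)   # 3×4 OneHotMatrix, indices are a plain Vector

# What convert(AbstractArray{Float32}, x) roughly boils down to:
dst = similar(x, Float32, size(x))   # new method: similar(x.indices, Float32, size(x))
copyto!(dst, x)                      # new method: a single broadcast, no scalar indexing
# the broadcast it performs:  dst .= reshape(x.indices, 1, 4) .== (1:3)

dst == Float32.(collect(x))          # true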

test/array.jl

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ end
   # linear indexing
   @test om[11] == om[1, 2]
   @test oa[52] == oa[2, 1, 2]
+  @test copyto!(rand(50,1), om) == reshape(om,:,1)  # hits invoke path
+  @test copyto!(rand(51,1), om)[1:50] == vec(om)
+  @test_throws BoundsError copyto!(rand(49,1), om)

   # bounds checks
   @test_throws BoundsError ov[0]
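The "hits invoke path" comment refers to the size-mismatch branch of the new `copyto!` method: when `size(dst) != size(src)` it forwards to Base's generic `copyto!(dst, ::AbstractArray{Bool,N})` via `invoke`, which copies element by element, leaves any extra destination slots untouched, and throws a BoundsError when the destination is too small. A small CPU-only illustration, with shapes chosen to mirror the tests above:

using OneHotArrays

om = onehotbatch(rand(1:10, 5), 1:10)   # 10×5 OneHotMatrix, 50 elements in total

copyto!(rand(50, 1), om)                # sizes differ -> invoke branch fills all 50 slots
copyto!(rand(51, 1), om)                # the 51st slot keeps its old random value
# copyto!(rand(49, 1), om)              # would throw BoundsError: destination too small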

test/gpu.jl

Lines changed: 27 additions & 1 deletion
@@ -1,5 +1,6 @@

 # Tests from Flux, probably not the optimal testset organisation!
+# (When CUDA is not available, these are run with JLArrays)

 @testset "CUDA" begin
   x = randn(5, 5)
@@ -18,14 +19,39 @@
   @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
 end

+@testset "gpu indexing" begin
+  x = onehotbatch([1, 2, 3, 2], 1:3)
+  cx = cu(x)
+
+  # These worked on OneHotArrays v0.2.7
+  @test cx[:, 1:2] isa OneHotMatrix
+  @test cx[:, 1:2].indices isa CuArray
+
+  @test @allowscalar cx[:,1] isa OneHotVector  # column, needs @allowscalar on v0.2.7
+  @test @allowscalar cx[:,1].indices isa Integer
+  @test collect(@allowscalar cx[:,end]) == [0,1,0]
+
+  @test cx[2,:] isa CuArray{Bool}  # row, is not onehot!
+  @test sum(cx[2,:]) == 2
+  @test collect(cx[2,:]) == x[2,:]
+
+  # These were broken on OneHotArrays v0.2.7
+  @test @allowscalar cx[2,2] == x[2,2]
+  @test collect(cx) == collect(x)
+  @test Matrix(cx) == Matrix(x) == collect(x)
+  @test Array{Float32}(cx) == Array{Float32}(x) == collect(x)
+  @test convert(AbstractArray{Float32}, cx) isa CuArray{Float32}
+  @test collect(convert(AbstractArray{Float32}, cx)) == collect(x)
+end
+
 @testset "onehot gpu" begin
   y = onehotbatch(ones(3), 1:2) |> cu;
   @test (repr("text/plain", y); true)

   gA = rand(3, 2) |> cu;

   #NOTE: this would require something that can copute gradient... we don't have that here?
-  #@test gradient(A -> sum(A * y), gA)[1] isa CuArray
+  #@test gradient(A -> sum(A * y), gA)[1] isa CuArray

   # some specialized implementations call only mul! and not *, so we must ensure this works
   @test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) ≈ gA*y

test/runtests.jl

Lines changed: 2 additions & 0 deletions
@@ -19,11 +19,13 @@ import CUDA
 if CUDA.functional()
   using CUDA # exports CuArray, etc
   CUDA.allowscalar(false)
+  using CUDA: @allowscalar
   @info "starting CUDA tests"
 else
   @info "CUDA not functional, testing with JLArrays instead"
   using JLArrays # fake GPU array, for testing
   JLArrays.allowscalar(false)
+  using JLArrays: @allowscalar
   cu = jl
   CuArray{T,N} = JLArray{T,N}
 end
