Skip to content

Commit 8f447ff

Browse files
mcabbott and ToucheSir authored
Faster path for onehotbatch(::CUArray{Int}, ::UnitRange) (#29)
* faster path for GPU creation * fixup * skip tests with real CUDA * indent Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com> * rm comment Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
1 parent 32e06c8 commit 8f447ff

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "OneHotArrays"
22
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
3-
version = "0.2.2"
3+
version = "0.2.3"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/onehot.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,24 @@ function _onehotbatch(data, labels, default)
101101
end
102102

103103
function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
104-
# lo, hi = extrema(data) # fails on Julia 1.6
105-
lo, hi = minimum(data), maximum(data)
104+
lo, hi = extrema(data)
106105
lo < first(labels) && error("Value $lo not found in labels")
107106
hi > last(labels) && error("Value $hi not found in labels")
108107
offset = 1 - first(labels)
109108
indices = UInt32.(data .+ offset)
110109
return OneHotArray(indices, length(labels))
111110
end
111+
# The bounds check with `extrema` above synchronises on the GPU, which is much slower than the rest of the function,
112+
# hence add a special method, with a less helpful error message:
113+
function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
114+
offset = 1 - first(labels)
115+
indices = map(data) do datum
116+
i = UInt32(datum + offset)
117+
checkbounds(labels, i)
118+
i
119+
end
120+
return OneHotArray(indices, length(labels))
121+
end
112122

113123
"""
114124
onecold(y::AbstractArray, labels = 1:size(y,1))

test/gpu.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,14 @@ end
3030
y1 = onehotbatch([1, 3, 0, 2], 0:9) |> cu
3131
y2 = onehotbatch([1, 3, 0, 2] |> cu, 0:9)
3232
@test y1.indices == y2.indices
33-
@test_broken y1 == y2
33+
@test_broken y1 == y2 # issue 28
3434

35-
@test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, 1:10)
36-
@test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, -2:2)
35+
if !CUDA.functional()
36+
# Here CUDA gives an error which @test_throws does not notice,
37+
# although with JLArrays @test_throws works fine.
38+
@test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, 1:10)
39+
@test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, -2:2)
40+
end
3741
end
3842

3943
@testset "onecold gpu" begin

0 commit comments

Comments
 (0)