Skip to content

Commit 32e06c8

Browse files
authored
Fast path onehotbatch(::Vector{Int}, ::UnitRange) (#27)
* add a fast path * add an error check * fixup, add tests * fix 1.6
1 parent d27d037 commit 32e06c8

File tree

4 files changed

+27
-1
lines changed

4 files changed

+27
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "OneHotArrays"
22
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
3-
version = "0.2.1"
3+
version = "0.2.2"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/onehot.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,16 @@ function _onehotbatch(data, labels, default)
100100
return OneHotArray(indices, length(labels))
101101
end
102102

103+
# Fast path for integer data whose labels form a contiguous unit range: each value
# maps to its one-hot index by a constant shift, so no per-element label lookup is needed.
#
# Throws an `ErrorException` (via `error`) when the smallest or largest value of `data`
# lies outside `labels`. Empty `data` has nothing to validate and yields an empty
# index array.
function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
    # `minimum`/`maximum` throw on an empty collection, so only range-check non-empty data.
    if !isempty(data)
        # lo, hi = extrema(data)  # fails on Julia 1.6
        lo, hi = minimum(data), maximum(data)
        lo < first(labels) && error("Value $lo not found in labels")
        hi > last(labels) && error("Value $hi not found in labels")
    end
    # Shift so that `first(labels)` lands on index 1.
    offset = 1 - first(labels)
    indices = UInt32.(data .+ offset)
    return OneHotArray(indices, length(labels))
end
112+
103113
"""
104114
onecold(y::AbstractArray, labels = 1:size(y,1))
105115

test/gpu.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ end
2626
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
2727
end
2828

29+
@testset "onehotbatch(::CuArray, ::UnitRange)" begin
    # Encode the same integers two ways: on the host and then moved with `cu`,
    # versus moved with `cu` first and then encoded.
    data = [1, 3, 0, 2]
    y1 = cu(onehotbatch(data, 0:9))
    y2 = onehotbatch(cu(data), 0:9)
    @test y1.indices == y2.indices
    @test_broken y1 == y2

    # `data` contains 0 (below 1:10) and 3 (above -2:2), so both calls must throw.
    for labels in (1:10, -2:2)
        @test_throws Exception onehotbatch(cu(data), labels)
    end
end
38+
2939
@testset "onecold gpu" begin
3040
y = onehotbatch(ones(3), 1:10) |> cu;
3141
l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']

test/onehot.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
@test onecold(onehot(-0.0, floats)) == 2 # as it uses isequal
2828
@test onecold(onehot(Inf, floats)) == 5
2929

30+
# UnitRange fast path
31+
@test onehotbatch([1,3,0,4], 0:4) == onehotbatch([1,3,0,4], Tuple(0:4))
32+
@test onehotbatch([2 3 7 4], 2:7) == onehotbatch([2 3 7 4], Tuple(2:7))
33+
@test_throws Exception onehotbatch([2, -1], 0:4)
34+
@test_throws Exception onehotbatch([2, 5], 0:4)
35+
3036
# inferrability tests
3137
@test @inferred(onehot(20, 10:10:30)) == [false, true, false]
3238
@test @inferred(onehot(40, (10,20,30), 20)) == [false, true, false]

0 commit comments

Comments
 (0)