Skip to content

Commit cce7ad0

Browse files
authored
Speed up onehotbatch (#1861)
* speed up `onehotbatch` * docstrings * wording
1 parent 6b335a4 commit cce7ad0

File tree

2 files changed

+51
-9
lines changed

2 files changed

+51
-9
lines changed

src/onehot.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ and [`onecold`](@ref) to reverse either of these, as well as to generalise `argm
114114
115115
# Examples
116116
```jldoctest
117-
julia> β = Flux.onehot(:b, [:a, :b, :c])
117+
julia> β = Flux.onehot(:b, (:a, :b, :c))
118118
3-element OneHotVector(::UInt32) with eltype Bool:
119119
120120
1
@@ -131,17 +131,24 @@ julia> hcat(αβγ...) # preserves sparsity
131131
```
132132
"""
133133
function onehot(x, labels)
134-
i = something(findfirst(isequal(x), labels), 0)
135-
i > 0 || error("Value $x is not in labels")
134+
i = _findval(x, labels)
135+
isnothing(i) && error("Value $x is not in labels")
136136
OneHotVector{UInt32, length(labels)}(i)
137137
end
138138

139139
function onehot(x, labels, default)
140-
i = something(findfirst(isequal(x), labels), 0)
141-
i > 0 || return onehot(default, labels)
140+
i = _findval(x, labels)
141+
isnothing(i) && return onehot(default, labels)
142142
OneHotVector{UInt32, length(labels)}(i)
143143
end
144144

145+
_findval(val, labels) = findfirst(isequal(val), labels)
146+
# Fast unrolled method for tuples:
147+
function _findval(val, labels::Tuple, i::Integer=1)
148+
ifelse(isequal(val, first(labels)), i, _findval(val, Base.tail(labels), i+1))
149+
end
150+
_findval(val, labels::Tuple{}, i::Integer) = nothing
151+
145152
"""
146153
onehotbatch(xs, labels, [default])
147154
@@ -156,9 +163,12 @@ If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
156163
`AbstractArray{Bool, M+1}` which is one-hot along the first dimension,
157164
i.e. `result[:, k...] == onehot(xs[k...], labels)`.
158165
166+
Note that `xs` can be any iterable, such as a string. And that using a tuple
167+
for `labels` will often speed up construction, certainly for less than 32 classes.
168+
159169
# Examples
160170
```jldoctest
161-
julia> oh = Flux.onehotbatch(collect("abracadabra"), 'a':'e', 'e')
171+
julia> oh = Flux.onehotbatch("abracadabra", 'a':'e', 'e')
162172
5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
163173
1 ⋅ ⋅ 1 ⋅ 1 ⋅ 1 ⋅ ⋅ 1
164174
⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅
@@ -173,7 +183,9 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
173183
3 6 15 3 9 3 12 3 6 15 3
174184
```
175185
"""
176-
onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l in ls])
186+
onehotbatch(ls, labels, default...) = _onehotbatch(ls, length(labels) < 32 ? Tuple(labels) : labels, default...)
187+
# NB function barier:
188+
_onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l in ls])
177189

178190
"""
179191
onecold(y::AbstractArray, labels = 1:size(y,1))
@@ -190,7 +202,7 @@ the same operation as `argmax(y, dims=1)` but sometimes a different return type.
190202
julia> Flux.onecold([false, true, false])
191203
2
192204
193-
julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
205+
julia> Flux.onecold([0.3, 0.2, 0.5], (:a, :b, :c))
194206
:c
195207
196208
julia> Flux.onecold([ 1 0 0 1 0 1 0 1 0 0 1

test/onehot.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,34 @@
1-
using Flux:onecold
1+
using Flux: onehot, onehotbatch, onecold
22
using Test
33

4+
@testset "onehot constructors" begin
5+
@test onehot(20, 10:10:30) == [false, true, false]
6+
@test onehot(20, (10,20,30)) == [false, true, false]
7+
@test onehot(40, (10,20,30), 20) == [false, true, false]
8+
9+
@test_throws Exception onehot('d', 'a':'c')
10+
@test_throws Exception onehot(:d, (:a, :b, :c))
11+
@test_throws Exception onehot('d', 'a':'c', 'e')
12+
@test_throws Exception onehot(:d, (:a, :b, :c), :e)
13+
14+
@test onehotbatch([20, 10], 10:10:30) == Bool[0 1; 1 0; 0 0]
15+
@test onehotbatch([20, 10], (10,20,30)) == Bool[0 1; 1 0; 0 0]
16+
@test onehotbatch([40, 10], (10,20,30), 20) == Bool[0 1; 1 0; 0 0]
17+
18+
@test onehotbatch("abc", 'a':'c') == Bool[1 0 0; 0 1 0; 0 0 1]
19+
@test onehotbatch("zbc", ('a', 'b', 'c'), 'a') == Bool[1 0 0; 0 1 0; 0 0 1]
20+
21+
@test_throws Exception onehotbatch([:a, :d], [:a, :b, :c])
22+
@test_throws Exception onehotbatch([:a, :d], (:a, :b, :c))
23+
@test_throws Exception onehotbatch([:a, :d], [:a, :b, :c], :e)
24+
@test_throws Exception onehotbatch([:a, :d], (:a, :b, :c), :e)
25+
26+
floats = (0.0, -0.0, NaN, -NaN, Inf, -Inf)
27+
@test onecold(onehot(0.0, floats)) == 1
28+
@test onecold(onehot(-0.0, floats)) == 2 # as it uses isequal
29+
@test onecold(onehot(Inf, floats)) == 5
30+
end
31+
432
@testset "onecold" begin
533
a = [1, 2, 5, 3.]
634
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
@@ -9,7 +37,9 @@ using Test
937
@test onecold(a) == 3
1038
@test onecold(A) == [3, 1, 4]
1139
@test onecold(a, labels) == 'C'
40+
@test onecold(a, Tuple(labels)) == 'C'
1241
@test onecold(A, labels) == ['C', 'A', 'D']
42+
@test onecold(A, Tuple(labels)) == ['C', 'A', 'D']
1343

1444
data = [:b, :a, :c]
1545
labels = [:a, :b, :c]

0 commit comments

Comments
 (0)