@@ -114,7 +114,7 @@ and [`onecold`](@ref) to reverse either of these, as well as to generalise `argm
114
114
115
115
# Examples
116
116
```jldoctest
117
- julia> β = Flux.onehot(:b, [ :a, :b, :c] )
117
+ julia> β = Flux.onehot(:b, ( :a, :b, :c) )
118
118
3-element OneHotVector(::UInt32) with eltype Bool:
119
119
⋅
120
120
1
@@ -131,17 +131,24 @@ julia> hcat(αβγ...) # preserves sparsity
131
131
```
132
132
"""
133
133
function onehot (x, labels)
134
- i = something ( findfirst ( isequal (x) , labels), 0 )
135
- i > 0 || error (" Value $x is not in labels" )
134
+ i = _findval (x , labels)
135
+ isnothing (i) && error (" Value $x is not in labels" )
136
136
OneHotVector {UInt32, length(labels)} (i)
137
137
end
138
138
139
139
function onehot (x, labels, default)
140
- i = something ( findfirst ( isequal (x) , labels), 0 )
141
- i > 0 || return onehot (default, labels)
140
+ i = _findval (x , labels)
141
+ isnothing (i) && return onehot (default, labels)
142
142
OneHotVector {UInt32, length(labels)} (i)
143
143
end
144
144
145
+ _findval (val, labels) = findfirst (isequal (val), labels)
146
+ # Fast unrolled method for tuples:
147
+ function _findval (val, labels:: Tuple , i:: Integer = 1 )
148
+ ifelse (isequal (val, first (labels)), i, _findval (val, Base. tail (labels), i+ 1 ))
149
+ end
150
+ _findval (val, labels:: Tuple{} , i:: Integer ) = nothing
151
+
145
152
"""
146
153
onehotbatch(xs, labels, [default])
147
154
@@ -156,9 +163,12 @@ If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
156
163
`AbstractArray{Bool, M+1}` which is one-hot along the first dimension,
157
164
i.e. `result[:, k...] == onehot(xs[k...], labels)`.
158
165
166
+ Note that `xs` can be any iterable, such as a string. And that using a tuple
167
+ for `labels` will often speed up construction, certainly for less than 32 classes.
168
+
159
169
# Examples
160
170
```jldoctest
161
- julia> oh = Flux.onehotbatch(collect( "abracadabra") , 'a':'e', 'e')
171
+ julia> oh = Flux.onehotbatch("abracadabra", 'a':'e', 'e')
162
172
5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
163
173
1 ⋅ ⋅ 1 ⋅ 1 ⋅ 1 ⋅ ⋅ 1
164
174
⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅
@@ -173,7 +183,9 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
173
183
3 6 15 3 9 3 12 3 6 15 3
174
184
```
175
185
"""
176
- onehotbatch (ls, labels, default... ) = batch ([onehot (l, labels, default... ) for l in ls])
186
+ onehotbatch (ls, labels, default... ) = _onehotbatch (ls, length (labels) < 32 ? Tuple (labels) : labels, default... )
187
+ # NB function barier:
188
+ _onehotbatch (ls, labels, default... ) = batch ([onehot (l, labels, default... ) for l in ls])
177
189
178
190
"""
179
191
onecold(y::AbstractArray, labels = 1:size(y,1))
@@ -190,7 +202,7 @@ the same operation as `argmax(y, dims=1)` but sometimes a different return type.
190
202
julia> Flux.onecold([false, true, false])
191
203
2
192
204
193
- julia> Flux.onecold([0.3, 0.2, 0.5], [ :a, :b, :c] )
205
+ julia> Flux.onecold([0.3, 0.2, 0.5], ( :a, :b, :c) )
194
206
:c
195
207
196
208
julia> Flux.onecold([ 1 0 0 1 0 1 0 1 0 0 1
0 commit comments