59
59
function Base. replace_in_print_matrix (x:: OneHotLike , i:: Integer , j:: Integer , s:: AbstractString )
60
60
x[i,j] ? s : _isonehot (x) ? Base. replace_with_centered_mark (s) : s
61
61
end
62
- function Base. replace_in_print_matrix (x:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike} , i:: Integer , j:: Integer , s:: AbstractString )
63
- x[i,j] ? s : _isonehot (parent (x)) ? Base. replace_with_centered_mark (s) : s
64
- end
65
62
63
+ # copy CuArray versions back before trying to print them:
66
64
Base. print_array (io:: IO , X:: OneHotLike{T, L, N, var"N+1", <:CuArray} ) where {T, L, N, var"N+1" } =
67
65
Base. print_array (io, cpu (X))
68
66
Base. print_array (io:: IO , X:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}} ) where {T, L, N, var"N+1" } =
@@ -152,12 +150,9 @@ This is a sparse matrix, which stores just a `Vector{UInt32}` containing the ind
152
150
If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)`
153
151
if `default` is given, else an error.
154
152
155
- If `xs` is a matrix, then the result is an `AbstractArray{Bool, 3}` which is one-hot along
156
- the first dimension, i.e. `result[:, k...] == onehot(xs[k...], labels)`.
157
-
158
- Matrix multiplication `M * onehotbatch(...)` is performed efficiently, by simply getting
159
- one element from every row of `M`. Some concatenation and reshape operations preserve onehot-ness.
160
- `OneHotArray`s can be moved to the GPU, to get for instance `OneHotMatrix(::CuArray{UInt32, 1})`.
153
+ If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
154
+ `AbstractArray{Bool, N+1}` which is one-hot along the first dimension,
155
+ i.e. `result[:, k...] == onehot(xs[k...], labels)`.
161
156
162
157
# Examples
163
158
```jldoctest
@@ -169,7 +164,7 @@ julia> oh = Flux.onehotbatch(collect("abracadabra"), 'a':'e', 'e')
169
164
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅
170
165
⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅
171
166
172
- julia> reshape(1:15, 3, 5) * oh
167
+ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently
173
168
3×11 Matrix{Int64}:
174
169
1 4 13 1 7 1 10 1 4 13 1
175
170
2 5 14 2 8 2 11 2 5 14 2
@@ -181,11 +176,11 @@ onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l
181
176
"""
182
177
onecold(y, [labels])
183
178
184
- Roughly the inverse operation of [`onehot`](@ref): finds the index of
179
+ Roughly the inverse operation of [`onehot`](@ref): Finds the index of
185
180
the largest element of `y`, or each column of `y`, and looks them up in `labels`.
186
181
187
- If `labels` are not specified, the default is integers `1:size(y,1)`,
188
- similar to `argmax(y, dims=1)`.
182
+ If `labels` are not specified, the default is integers `1:size(y,1)` --
183
+ the same as `argmax(y, dims=1)` but a different return type .
189
184
190
185
# Examples
191
186
```jldoctest
0 commit comments