Skip to content

Commit 4e92c28

Browse files
committed
suggestions
1 parent 2839a2f commit 4e92c28

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

src/onehot.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,8 @@ end
5959
function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString)
6060
x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
6161
end
62-
function Base.replace_in_print_matrix(x::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike}, i::Integer, j::Integer, s::AbstractString)
63-
x[i,j] ? s : _isonehot(parent(x)) ? Base.replace_with_centered_mark(s) : s
64-
end
6562

63+
# copy CuArray versions back before trying to print them:
6664
Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
6765
Base.print_array(io, cpu(X))
6866
Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
@@ -152,12 +150,9 @@ This is a sparse matrix, which stores just a `Vector{UInt32}` containing the ind
152150
If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)`
153151
if `default` is given, else an error.
154152
155-
If `xs` is a matrix, then the result is an `AbstractArray{Bool, 3}` which is one-hot along
156-
the first dimension, i.e. `result[:, k...] == onehot(xs[k...], labels)`.
157-
158-
Matrix multiplication `M * onehotbatch(...)` is performed efficiently, by simply getting
159-
one element from every row of `M`. Some concatenation and reshape operations preserve onehot-ness.
160-
`OneHotArray`s can be moved to the GPU, to get for instance `OneHotMatrix(::CuArray{UInt32, 1})`.
153+
If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
154+
`AbstractArray{Bool, N+1}` which is one-hot along the first dimension,
155+
i.e. `result[:, k...] == onehot(xs[k...], labels)`.
161156
162157
# Examples
163158
```jldoctest
@@ -169,7 +164,7 @@ julia> oh = Flux.onehotbatch(collect("abracadabra"), 'a':'e', 'e')
169164
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅
170165
⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅
171166
172-
julia> reshape(1:15, 3, 5) * oh
167+
julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently
173168
3×11 Matrix{Int64}:
174169
1 4 13 1 7 1 10 1 4 13 1
175170
2 5 14 2 8 2 11 2 5 14 2
@@ -181,11 +176,11 @@ onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l
181176
"""
182177
onecold(y, [labels])
183178
184-
Roughly the inverse operation of [`onehot`](@ref): finds the index of
179+
Roughly the inverse operation of [`onehot`](@ref): Finds the index of
185180
the largest element of `y`, or each column of `y`, and looks them up in `labels`.
186181
187-
If `labels` are not specified, the default is integers `1:size(y,1)`,
188-
similar to `argmax(y, dims=1)`.
182+
If `labels` are not specified, the default is integers `1:size(y,1)` --
183+
the same as `argmax(y, dims=1)` but a different return type.
189184
190185
# Examples
191186
```jldoctest

0 commit comments

Comments
 (0)