Skip to content

Commit 15329c4

Browse files
committed
tweaks
1 parent 1a57f7d commit 15329c4

File tree

1 file changed

+33
-14
lines changed

1 file changed

+33
-14
lines changed

src/onehot.jl

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import Adapt
22
import .CUDA
33

4+
"""
5+
OneHotArray{T,L,N,M,I} <: AbstractArray{Bool,M}
6+
7+
These are constructed by [`onehot`](@ref) and [`onehotbatch`](@ref).
8+
Parameter `I` is the type of the underlying storage, and `T` its eltype.
9+
"""
410
struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"}
511
indices::I
612
end
@@ -15,7 +21,9 @@ _indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) =
1521
const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
1622
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}
1723

24+
@doc @doc(OneHotArray)
1825
OneHotVector(idx, L) = OneHotArray(idx, L)
26+
@doc @doc(OneHotArray)
1927
OneHotMatrix(indices, L) = OneHotArray(indices, L)
2028

2129
# use this type so reshaped arrays hit fast paths
@@ -49,12 +57,18 @@ end
4957

5058
# this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots:
5159
function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString)
52-
CUDA.@allowscalar(x[i,j]) ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
60+
# CUDA.@allowscalar(x[i,j]) ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
61+
x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
5362
end
5463
function Base.replace_in_print_matrix(x::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike}, i::Integer, j::Integer, s::AbstractString)
5564
CUDA.@allowscalar(x[i,j]) ? s : _isonehot(parent(x)) ? Base.replace_with_centered_mark(s) : s
5665
end
5766

67+
# Base.show(io::IO, x::OneHotLike) = show(io, convert(Array{Bool}, cpu(x))) # helps string(cu(y))
68+
69+
Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
70+
Base.print_array(io, cpu(X))
71+
5872
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
5973
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
6074

@@ -102,26 +116,31 @@ and [`onecold`](@ref) for a `labels`-aware `argmax`.
102116
103117
# Examples
104118
```jldoctest
105-
julia> Flux.onehot(:b, [:a, :b, :c])
119+
julia> β = Flux.onehot(:b, [:a, :b, :c])
106120
3-element OneHotVector(::UInt32) with eltype Bool:
107121
108122
1
109123
110124
111-
julia> Flux.onehot(-33, 0:19, 0)' # uses default
112-
1×20 adjoint(OneHotVector(::UInt32)) with eltype Bool:
113-
1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
125+
julia> αβγ = (Flux.onehot(0, 0:2), β, Flux.onehot(:z, [:a, :b, :c], :c)) # uses default
126+
(Bool[1, 0, 0], Bool[0, 1, 0], Bool[0, 0, 1])
127+
128+
julia> hcat(αβγ...) # preserves sparsity
129+
3×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
130+
1 ⋅ ⋅
131+
⋅ 1 ⋅
132+
⋅ ⋅ 1
114133
```
115134
"""
116-
function onehot(l, labels)
117-
i = something(findfirst(isequal(l), labels), 0)
118-
i > 0 || error("Value $l is not in labels")
135+
function onehot(x, labels)
136+
i = something(findfirst(isequal(x), labels), 0)
137+
i > 0 || error("Value $x is not in labels")
119138
OneHotVector{UInt32, length(labels)}(i)
120139
end
121140

122-
function onehot(l, labels, unk)
123-
i = something(findfirst(isequal(l), labels), 0)
124-
i > 0 || return onehot(unk, labels)
141+
function onehot(x, labels, default)
142+
i = something(findfirst(isequal(x), labels), 0)
143+
i > 0 || return onehot(default, labels)
125144
OneHotVector{UInt32, length(labels)}(i)
126145
end
127146

@@ -163,16 +182,16 @@ onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l
163182
"""
164183
onecold(y, [labels])
165184
166-
Roughly the inverse operations of [`onehot`](@ref): finds the index of
185+
Roughly the inverse operation of [`onehot`](@ref): finds the index of
167186
the largest element of `y`, or each column of `y`, and looks them up in `labels`.
168187
169188
If `labels` are not specified, the default is integers `1:size(y,1)`,
170189
similar to `argmax(y, dims=1)`.
171190
172191
# Examples
173192
```jldoctest
174-
julia> Flux.onecold([true, false, false], [:a, :b, :c])
175-
:a
193+
julia> Flux.onecold([false, true, false])
194+
2
176195
177196
julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
178197
:c

0 commit comments

Comments
 (0)