1
1
import Adapt
2
2
import . CUDA
3
3
4
+ """
5
+ OneHotArray{T,L,N,M,I} <: AbstractArray{Bool,M}
6
+
7
+ These are constructed by [`onehot`](@ref) and [`onehotbatch`](@ref).
8
+ Parameter `I` is the type of the underlying storage, and `T` its eltype.
9
+ """
4
10
struct OneHotArray{T<: Integer , L, N, var"N+1" , I<: Union{T, AbstractArray{T, N}} } <: AbstractArray{Bool, var"N+1"}
5
11
indices:: I
6
12
end
@@ -15,7 +21,9 @@ _indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) =
15
21
const OneHotVector{T, L} = OneHotArray{T, L, 0 , 1 , T}
16
22
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1 , 2 , I}
17
23
24
+ @doc @doc (OneHotArray)
18
25
OneHotVector (idx, L) = OneHotArray (idx, L)
26
+ @doc @doc (OneHotArray)
19
27
OneHotMatrix (indices, L) = OneHotArray (indices, L)
20
28
21
29
# use this type so reshaped arrays hit fast paths
49
57
50
58
# this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots:
51
59
function Base. replace_in_print_matrix (x:: OneHotLike , i:: Integer , j:: Integer , s:: AbstractString )
52
- CUDA. @allowscalar (x[i,j]) ? s : _isonehot (x) ? Base. replace_with_centered_mark (s) : s
60
+ # CUDA.@allowscalar(x[i,j]) ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
61
+ x[i,j] ? s : _isonehot (x) ? Base. replace_with_centered_mark (s) : s
53
62
end
54
63
function Base. replace_in_print_matrix (x:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike} , i:: Integer , j:: Integer , s:: AbstractString )
55
64
CUDA. @allowscalar (x[i,j]) ? s : _isonehot (parent (x)) ? Base. replace_with_centered_mark (s) : s
56
65
end
57
66
67
+ # Base.show(io::IO, x::OneHotLike) = show(io, convert(Array{Bool}, cpu(x))) # helps string(cu(y))
68
+
69
+ Base. print_array (io:: IO , X:: OneHotLike{T, L, N, var"N+1", <:CuArray} ) where {T, L, N, var"N+1" } =
70
+ Base. print_array (io, cpu (X))
71
+
58
72
_onehot_bool_type (x:: OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}} ) where N = Array{Bool, N}
59
73
_onehot_bool_type (x:: OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray} ) where N = CuArray{Bool, N}
60
74
@@ -102,26 +116,31 @@ and [`onecold`](@ref) for a `labels`-aware `argmax`.
102
116
103
117
# Examples
104
118
```jldoctest
105
- julia> Flux.onehot(:b, [:a, :b, :c])
119
+ julia> β = Flux.onehot(:b, [:a, :b, :c])
106
120
3-element OneHotVector(::UInt32) with eltype Bool:
107
121
⋅
108
122
1
109
123
⋅
110
124
111
- julia> Flux.onehot(-33, 0:19, 0)' # uses default
112
- 1×20 adjoint(OneHotVector(::UInt32)) with eltype Bool:
113
- 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
125
+ julia> αβγ = (Flux.onehot(0, 0:2), β, Flux.onehot(:z, [:a, :b, :c], :c)) # uses default
126
+ (Bool[1, 0, 0], Bool[0, 1, 0], Bool[0, 0, 1])
127
+
128
+ julia> hcat(αβγ...) # preserves sparsity
129
+ 3×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
130
+ 1 ⋅ ⋅
131
+ ⋅ 1 ⋅
132
+ ⋅ ⋅ 1
114
133
```
115
134
"""
116
- function onehot (l , labels)
117
- i = something (findfirst (isequal (l ), labels), 0 )
118
- i > 0 || error (" Value $l is not in labels" )
135
+ function onehot (x , labels)
136
+ i = something (findfirst (isequal (x ), labels), 0 )
137
+ i > 0 || error (" Value $x is not in labels" )
119
138
OneHotVector {UInt32, length(labels)} (i)
120
139
end
121
140
122
- function onehot (l , labels, unk )
123
- i = something (findfirst (isequal (l ), labels), 0 )
124
- i > 0 || return onehot (unk , labels)
141
+ function onehot (x , labels, default )
142
+ i = something (findfirst (isequal (x ), labels), 0 )
143
+ i > 0 || return onehot (default , labels)
125
144
OneHotVector {UInt32, length(labels)} (i)
126
145
end
127
146
@@ -163,16 +182,16 @@ onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l
163
182
"""
164
183
onecold(y, [labels])
165
184
166
- Roughly the inverse operations of [`onehot`](@ref): finds the index of
185
+ Roughly the inverse operation of [`onehot`](@ref): finds the index of
167
186
the largest element of `y`, or each column of `y`, and looks them up in `labels`.
168
187
169
188
If `labels` are not specified, the default is integers `1:size(y,1)`,
170
189
similar to `argmax(y, dims=1)`.
171
190
172
191
# Examples
173
192
```jldoctest
174
- julia> Flux.onecold([true, false, false], [:a, :b, :c ])
175
- :a
193
+ julia> Flux.onecold([false, true, false ])
194
+ 2
176
195
177
196
julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
178
197
:c
0 commit comments