|
49 | 49 |
|
50 | 50 | # this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots:
|
51 | 51 | function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString)
|
52 |
| - x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s |
| 52 | + CUDA.@allowscalar(x[i,j]) ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s |
53 | 53 | end
|
54 | 54 | function Base.replace_in_print_matrix(x::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike}, i::Integer, j::Integer, s::AbstractString)
|
55 |
| - x[i,j] ? s : _isonehot(parent(x)) ? Base.replace_with_centered_mark(s) : s |
| 55 | + CUDA.@allowscalar(x[i,j]) ? s : _isonehot(parent(x)) ? Base.replace_with_centered_mark(s) : s |
56 | 56 | end
|
57 | 57 |
|
58 | 58 | _onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
|
@@ -134,12 +134,13 @@ for vector `xs`. This is a sparse matrix, which stores just `Vector{UInt32}` con
|
134 | 134 | If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)`
|
135 | 135 | if `default` is given, else an error.
|
136 | 136 |
|
137 |
| -Matrix multiplication `M * onehotbatch(...)` is performed efficiently, by simply getting |
138 |
| -one element from every row of `M`. |
139 |
| -
|
140 | 137 | If `xs` is a matrix, then the result is an `AbstractArray{Bool, 3}` which is one-hot along
|
141 | 138 | the first dimension, i.e. `result[:, k...] == onehot(xs[k...], labels)`.
|
142 | 139 |
|
| 140 | +Matrix multiplication `M * onehotbatch(...)` is performed efficiently, by simply getting |
| 141 | +one element from every row of `M`. Some concatenation and reshape operations preserve onehot-ness. |
| 142 | +`OneHotArray`s can be moved to the GPU, to get for instance `OneHotMatrix(::CuArray{UInt32, 1})`. |
| 143 | +
|
143 | 144 | # Examples
|
144 | 145 | ```jldoctest
|
145 | 146 | julia> oh = Flux.onehotbatch(collect("abracadabra"), 'a':'e', 'e')
|
|
0 commit comments