Skip to content

Commit 7b851f6

Browse files
committed
GPU-friendly
1 parent a0a390e commit 7b851f6

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/onehot.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ end
4949

5050
# this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots:
5151
function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString)
52-
x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
52+
CUDA.@allowscalar(x[i,j]) ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
5353
end
5454
function Base.replace_in_print_matrix(x::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike}, i::Integer, j::Integer, s::AbstractString)
55-
x[i,j] ? s : _isonehot(parent(x)) ? Base.replace_with_centered_mark(s) : s
55+
CUDA.@allowscalar(x[i,j]) ? s : _isonehot(parent(x)) ? Base.replace_with_centered_mark(s) : s
5656
end
5757

5858
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
@@ -134,12 +134,13 @@ for vector `xs`. This is a sparse matrix, which stores just `Vector{UInt32}` con
134134
If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)`
135135
if `default` is given, else an error.
136136
137-
Matrix multiplication `M * onehotbatch(...)` is performed efficiently, by simply getting
138-
one element from every row of `M`.
139-
140137
If `xs` is a matrix, then the result is an `AbstractArray{Bool, 3}` which is one-hot along
141138
the first dimension, i.e. `result[:, k...] == onehot(xs[k...], labels)`.
142139
140+
Matrix multiplication `M * onehotbatch(...)` is performed efficiently, by simply getting
141+
one element from every row of `M`. Some concatenation and reshape operations preserve onehot-ness.
142+
`OneHotArray`s can be moved to the GPU, to get for instance `OneHotMatrix(::CuArray{UInt32, 1})`.
143+
143144
# Examples
144145
```jldoctest
145146
julia> oh = Flux.onehotbatch(collect("abracadabra"), 'a':'e', 'e')

0 commit comments

Comments
 (0)