We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2398dfc commit 658effcCopy full SHA for 658effc
src/layers/basic.jl
@@ -418,7 +418,8 @@ function Embedding(in::Integer, out::Integer;
418
end
419
420
(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
421
-(m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
+(m::Embedding)(x::Int) = m.weight[:, x]
422
+(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
423
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
424
425
function Base.show(io::IO, m::Embedding)
0 commit comments