Skip to content

Commit 658effc

Browse files
committed
update Embedding to use gather for AbstractVector
1 parent 2398dfc commit 658effc

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/layers/basic.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,8 @@ function Embedding(in::Integer, out::Integer;
418418
end
419419

420420
(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
421-
(m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
421+
(m::Embedding)(x::Int) = m.weight[:, x]
422+
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
422423
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
423424

424425
function Base.show(io::IO, m::Embedding)

0 commit comments

Comments
 (0)