Skip to content

Commit 4c242f1

Browse files
committed
add outputsize special case for NNlib.gather
1 parent fe100d5 commit 4c242f1

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/layers/basic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(las
470470

471471

472472
(m::Embedding)(x::Integer) = m.weight[:, x]
473-
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
473+
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
474474
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
475475
(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
476476
(m::Embedding)(x::AbstractVecOrMat{Bool}) = m.weight * x # handles OneHotLikeVector, OneHotLikeMatrix

src/outputsize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,4 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
169169
end
170170
end
171171

172-
(m::Embedding)(x::AbstractVecOrMat{<:Nil}) = fill(nil, size(m.weight, 1), length(x))
172+
NNlib.gather!(dst::AbstractArray, ::AbstractArray, ::AbstractArray{<:Nil}) = fill(nil, size(dst)...)

0 commit comments

Comments
 (0)