Skip to content

Commit b84c56d

Browse files
committed
updated and exported Embedding
1 parent a6ff89f commit b84c56d

File tree

3 files changed

+5
-9
lines changed

3 files changed

+5
-9
lines changed

src/Flux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using MacroTools: @forward
1010
using Zygote: Params, @adjoint, gradient, pullback, @nograd
1111
export gradient
1212

13-
export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
13+
export Chain, Dense, Maxout, SkipConnection, Parallel, Embedding, flatten,
1414
RNN, LSTM, GRU,
1515
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
1616
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,

src/layers/basic.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -470,13 +470,10 @@ Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(las
470470

471471

472472
(m::Embedding)(x::Integer) = m.weight[:, x]
473-
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
473+
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
474474
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
475-
476-
function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
477-
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
478-
return m(onecold(x))
479-
end
475+
(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
476+
(m::Embedding)(x::AbstractVecOrMat{Bool}) = m.weight * x # handles OneHotLikeVector, OneHotLikeMatrix
480477

481478
function Base.show(io::IO, m::Embedding)
482479
print(io, "Embedding($(size(m.weight, 2)) => $(size(m.weight, 1)))")

src/outputsize.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,5 +169,4 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
169169
end
170170
end
171171

172-
(m::Embedding)(x::AbstractVector{<:Nil}) = fill(nil, size(m.weight, 1), length(x))
173-
(m::Embedding)(x::AbstractArray{<:Nil}) = fill(nil, size(m.weight, 1), size(x)...)
172+
(m::Embedding)(x::AbstractVecOrMat{<:Nil}) = fill(nil, size(m.weight, 1), length(x))

0 commit comments

Comments
 (0)