Skip to content

Commit 9d6bed7

Browse files
committed
updated and exported Embedding
1 parent 1bc9421 commit 9d6bed7

File tree

3 files changed

+5
-9
lines changed

3 files changed

+5
-9
lines changed

src/Flux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using MacroTools: @forward
1010
using Zygote: Params, @adjoint, gradient, pullback, @nograd
1111
export gradient
1212

13-
export Chain, Dense, Maxout, SkipConnection, Parallel,
13+
export Chain, Dense, Maxout, SkipConnection, Parallel, Embedding,
1414
RNN, LSTM, GRU, GRUv3,
1515
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
1616
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,

src/layers/basic.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -515,13 +515,10 @@ Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(las
515515

516516

517517
(m::Embedding)(x::Integer) = m.weight[:, x]
518-
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
518+
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
519519
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
520-
521-
function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
522-
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
523-
return m(onecold(x))
524-
end
520+
(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
521+
(m::Embedding)(x::AbstractVecOrMat{Bool}) = m.weight * x # handles OneHotLikeVector, OneHotLikeMatrix
525522

526523
function Base.show(io::IO, m::Embedding)
527524
print(io, "Embedding($(size(m.weight, 2)) => $(size(m.weight, 1)))")

src/outputsize.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,5 +169,4 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
169169
end
170170
end
171171

172-
(m::Embedding)(x::AbstractVector{<:Nil}) = fill(nil, size(m.weight, 1), length(x))
173-
(m::Embedding)(x::AbstractArray{<:Nil}) = fill(nil, size(m.weight, 1), size(x)...)
172+
(m::Embedding)(x::AbstractVecOrMat{<:Nil}) = fill(nil, size(m.weight, 1), length(x))

0 commit comments

Comments
 (0)