@@ -644,14 +644,17 @@ function Base.show(io::IO, m::PairwiseFusion)
644
644
end
645
645
646
646
"""
647
- Embedding(in => out; init=randn )
647
+ Embedding(in => out; init=randn32 )
648
648
649
649
A lookup table that stores embeddings of dimension `out`
650
- for a vocabulary of size `in`.
650
+ for a vocabulary of size `in`, as a trainable matrix .
651
651
652
652
This layer is often used to store word embeddings and retrieve them using indices.
653
- The input to the layer can be either a vector of indexes
654
- or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
653
+ The input to the layer can be a vocabulary index in `1:in`, an array of indices,
654
+ or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
655
+
656
+ For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
657
+ For one-hot `x`, the result is of size `(out, size(x)[2:end]...)`.
655
658
656
659
# Examples
657
660
```jldoctest
@@ -684,10 +687,9 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
684
687
(m:: Embedding )(x:: AbstractVector ) = NNlib. gather (m. weight, x)
685
688
(m:: Embedding )(x:: AbstractArray ) = reshape (m (vec (x)), :, size (x)... )
686
689
687
- function (m:: Embedding )(x:: Union{OneHotVector{T,L}, OneHotMatrix{T,L}} ) where {T,L}
688
- size (m. weight, 2 ) == L || throw (DimensionMismatch (" Matrix column must correspond with OneHot size: $(size (m. weight, 2 )) != $L " ))
689
- return m (onecold (x))
690
- end
690
+ (m:: Embedding )(x:: AbstractVector{Bool} ) = m. weight * x # usually OneHotVector
691
+ (m:: Embedding )(x:: AbstractMatrix{Bool} ) = m. weight * x # usually OneHotMatrix
692
+ (m:: Embedding )(x:: AbstractArray{Bool} ) = reshape (m (reshape (x, size (x,1 ), :)), :, size (x)[2 : end ]. .. )
691
693
692
694
function Base. show (io:: IO , m:: Embedding )
693
695
print (io, " Embedding(" , size (m. weight, 2 ), " => " , size (m. weight, 1 ), " )" )
0 commit comments