diff --git a/src/Flux.jl b/src/Flux.jl
index 5f906a0528..a38115f1ef 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -14,7 +14,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient
 
 using ChainRulesCore
 
-export Chain, Dense, Maxout, SkipConnection, Parallel,
+export Chain, Dense, Maxout, SkipConnection, Parallel, Embedding,
        RNN, LSTM, GRU, GRUv3, SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 952ff7d444..9f47cba309 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -483,7 +483,8 @@ function Base.show(io::IO, m::Parallel)
 end
 
 """
-    Embedding(in => out; init=randn)
+    Embedding(in => out; init=randn32)
+    Embedding(weight::AbstractMatrix)
 
 A lookup table that stores embeddings of dimension `out`
 for a vocabulary of size `in`.
@@ -493,41 +494,39 @@ The input to the layer can be either a vector of indexes
 or the corresponding [onehot encoding](@ref Flux.OneHotArray).
 
 # Examples
-```jldoctest
-julia> vocab_size, embed_size = 1000, 4;
-
-julia> model = Flux.Embedding(vocab_size => embed_size)
-Embedding(1000 => 4)  # 4_000 parameters
-
-julia> vocab_idxs = [1, 722, 53, 220, 3];
-
-julia> x = Flux.OneHotMatrix(vocab_idxs, vocab_size);  summary(x)
-"1000×5 OneHotMatrix(::Vector{Int64}) with eltype Bool"
-
-julia> model(x) |> summary
-"4×5 Matrix{Float32}"
-
-julia> model(vocab_idxs) == model(x)
+```jldoctest
+julia> m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0)
+Embedding(26 => 2)
+
+julia> m(5)  # embedding vector for 5th element
+2-element Vector{Float32}:
+ 2.01
+ 3.01
+
+julia> m([6, 15, 15])  # applied to a batch
+2×3 Matrix{Float32}:
+  4.01  22.01  22.01
+  5.01  23.01  23.01
+
+julia> ans == m(Flux.onehotbatch("foo", 'a':'z'))
 true
 ```
 """
-struct Embedding{W}
+struct Embedding{W <: AbstractMatrix}
   weight::W
 end
 
 @functor Embedding
 
-Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
+Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(last(dims), first(dims)))
 
 (m::Embedding)(x::Integer) = m.weight[:, x]
 (m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
 (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
-
-function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
-  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
-  return m(onecold(x))
-end
+(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
+(m::Embedding)(x::AbstractVecOrMat{Bool}) = m.weight * x  # handles OneHotLikeVector, OneHotLikeMatrix
 
 function Base.show(io::IO, m::Embedding)
-  print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
+  print(io, "Embedding($(size(m.weight, 2)) => $(size(m.weight, 1)))")
 end
diff --git a/src/onehot.jl b/src/onehot.jl
index 86afd513dc..814645c6d2 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -33,6 +33,9 @@ const OneHotLike{T, L, N, var"N+1", I} =
   Union{OneHotArray{T, L, N, var"N+1", I},
         Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
 
+const OneHotLikeVector{T, L} = OneHotLike{T, L, 0, 1, T}
+const OneHotLikeMatrix{T, L, I} = OneHotLike{T, L, 1, 2, I}
+
 _isonehot(x::OneHotArray) = true
 _isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
diff --git a/src/outputsize.jl b/src/outputsize.jl
index 1caea9e16b..02c92c7174 100644
--- a/src/outputsize.jl
+++ b/src/outputsize.jl
@@ -168,3 +168,7 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
     end
   end
 end
+
+function NNlib.gather(src::AbstractArray{Tsrc, Nsrc}, idx::AbstractArray{<:Nil}) where {Tsrc, Nsrc}
+  fill(nil, (size(src)[1:Nsrc-1]..., size(idx)...))
+end
\ No newline at end of file
diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl
index 396e6c0ab5..3c28c6a268 100644
--- a/test/cuda/layers.jl
+++ b/test/cuda/layers.jl
@@ -127,13 +127,13 @@ gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
 gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
 
 embedding = [Flux.Embedding]
-gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
-gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
-gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
-gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
-gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
-gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
-gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
+gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2)
+gpu_gradtest("Embedding repeated indices", embedding, rand(1:10, 10^3), 10 => 2)
+gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2)
+gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2)
+gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2)
+gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2)
+gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix(rand(1:10, 10^3), 10), 10 => 2)
 
 @testset "function layers" begin
   x = rand(Float32, 3,3)
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index ca8e15a643..8ad58b26f1 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -276,7 +276,7 @@ import Flux: activations
 
   @testset "Embedding" begin
     vocab_size, embed_size = 10, 4
-    m = Flux.Embedding(vocab_size, embed_size)
+    m = Flux.Embedding(vocab_size => embed_size)
     @test size(m.weight) == (embed_size, vocab_size)
 
     x = rand(1:vocab_size, 3)
@@ -284,9 +284,9 @@ import Flux: activations
     @test y isa Matrix{Float32}
     @test y ≈ m.weight[:,x]
     x2 = OneHotMatrix(x, vocab_size)
-    y2 = m(x2)
-    @test y2 isa Matrix{Float32}
-    @test y2 ≈ y
+    @test m(x2) isa Matrix{Float32}
+    @test m(x2) ≈ y
+    @test m(collect(x2)) ≈ y
     @test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
 
     x = rand(1:vocab_size, 3, 4)
@@ -297,6 +297,9 @@ import Flux: activations
     @test m(2) ≈ m.weight[:,2]
     @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
     @test_throws DimensionMismatch m(OneHotVector(3, 1000))
+
+    x = onehotbatch(rand(1:vocab_size, 4, 3, 4, 5), 1:vocab_size)
+    @test m(x) ≈ m(onecold(x))
   end
 end
diff --git a/test/outputsize.jl b/test/outputsize.jl
index 2c90811dcb..d1ae94223a 100644
--- a/test/outputsize.jl
+++ b/test/outputsize.jl
@@ -155,3 +155,10 @@ end
   @test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
   @test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
 end
+
+@testset "embedding" begin
+  m = Embedding(3=>5)
+  @test outputsize(m, (2,)) == (5, 2)
+  @test outputsize(m, (2, 3)) == (5, 2, 3)
+  @test outputsize(m, (2, 3, 4)) == (5, 2, 3, 4)
+end
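
A quick sketch of the user-facing behavior these changes enable, condensed from the new docstring and tests above; the exact results assume a Flux build with this patch applied:

```julia
using Flux

# New constructor: wrap an existing weight matrix directly
# (the `in => out` Pair constructor still works as before).
m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0)   # Embedding(26 => 2)

m(5)            # column 5 of m.weight: Float32[2.01, 3.01]
m([6, 15, 15])  # 2×3 Matrix{Float32}, one column per index

# One-hot input now goes through plain multiplication, m.weight * x,
# instead of the old onecold round-trip:
m(Flux.onehotbatch("foo", 'a':'z')) == m([6, 15, 15])  # true
```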
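And a minimal illustration of the `src/outputsize.jl` hunk, mirroring the new `test/outputsize.jl` testset: `outputsize` evaluates the layer on `nil`-filled dummy arrays, and `NNlib.gather` would otherwise have no method for `Nil` indices, hence the added overload.

```julia
using Flux

m = Embedding(3 => 5)

# The index array may have any number of dimensions;
# the embedding dimension is prepended to its size.
Flux.outputsize(m, (2,))       # (5, 2)
Flux.outputsize(m, (2, 3))     # (5, 2, 3)
Flux.outputsize(m, (2, 3, 4))  # (5, 2, 3, 4)
```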