From cb3a4ca7d63816e641609555c26defb42a3196c3 Mon Sep 17 00:00:00 2001 From: Manikya Date: Tue, 6 Jul 2021 21:30:02 +0530 Subject: [PATCH 1/9] Embedding special case for outputsize --- src/outputsize.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/outputsize.jl b/src/outputsize.jl index 1caea9e16b..3613502671 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -168,3 +168,6 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims)) end end end + +(m::Embedding)(x::AbstractVector{<:Nil}) = fill(nil, size(m.weight, 1), length(x)) +(m::Embedding)(x::AbstractArray{<:Nil}) = fill(nil, size(m.weight, 1), size(x)...) From eb489c3b21eb69d54e1742b787630476b8daeae2 Mon Sep 17 00:00:00 2001 From: Manikya Bardhan Date: Sun, 11 Jul 2021 02:16:33 +0530 Subject: [PATCH 2/9] Apply suggestions from code review Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Co-authored-by: Kyle Daruwalla --- src/layers/basic.jl | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 952ff7d444..4f0d0a50f8 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -493,25 +493,26 @@ The input to the layer can be either a vector of indexes or the corresponding [onehot encoding](@ref Flux.OneHotArray). # Examples -```jldoctest -julia> vocab_size, embed_size = 1000, 4; - -julia> model = Flux.Embedding(vocab_size => embed_size) -Embedding(1000 => 4) # 4_000 parameters -julia> vocab_idxs = [1, 722, 53, 220, 3]; +```jldoctest +julia> m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0) +Embedding(26 => 2) -julia> x = Flux.OneHotMatrix(vocab_idxs, vocab_size); summary(x) -"1000×5 OneHotMatrix(::Vector{Int64}) with eltype Bool" +julia> m(5) # embedding vector for 5th element +2-element Vector{Float32}: + 2.01 + 3.01 -julia> model(x) |> summary -"4×5 Matrix{Float32}" +julia> m([6, 15, 15]) # applied to a batch +2×3 Matrix{Float32}: + 4.01 22.01 22.01 + 5.01 23.01 23.01 -julia> model(vocab_idxs) == model(x) +julia> ans == m(Flux.OneHotMatrix([6, 15, 15], 26)) true ``` """ -struct Embedding{W} +struct Embedding{W <: AbstractMatrix} weight::W end @@ -529,5 +530,5 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T end function Base.show(io::IO, m::Embedding) - print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")") + print(io, "Embedding($(size(m.weight, 2)) => $(size(m.weight, 1)))") end From f702260781279fbff3ff121ca70c887c2719c1ac Mon Sep 17 00:00:00 2001 From: Manikya Date: Mon, 12 Jul 2021 22:18:26 +0530 Subject: [PATCH 3/9] update Embedding constructor Updated Embedding constructor to use `=>` and added OneHotLikeVector and OneHotLikeMatrix consts. --- src/layers/basic.jl | 2 +- src/onehot.jl | 3 +++ test/cuda/layers.jl | 14 +++++++------- test/layers/basic.jl | 2 +- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 4f0d0a50f8..ca58407f23 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -518,7 +518,7 @@ end @functor Embedding -Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in)) +Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(last(dims), first(dims))) (m::Embedding)(x::Integer) = m.weight[:, x] (m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x) diff --git a/src/onehot.jl b/src/onehot.jl index 86afd513dc..814645c6d2 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -33,6 +33,9 @@ const OneHotLike{T, L, N, var"N+1", I} = Union{OneHotArray{T, L, N, var"N+1", I}, Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}} +const OneHotLikeVector{T, L} = OneHotLike{T, L, 0, 1, T} +const OneHotLikeMatrix{T, L, I} = OneHotLike{T, L, 1, 2, I} + _isonehot(x::OneHotArray) = true _isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L) diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl index 396e6c0ab5..13a3830fb6 100644 --- a/test/cuda/layers.jl +++ b/test/cuda/layers.jl @@ -127,13 +127,13 @@ gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3) gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3) embedding = [Flux.Embedding] -gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2) -gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2) -gpu_gradtest("Embedding integer index", embedding, 1, 5, 2) -gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2) -gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2) -gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2) -gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2) +gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2) +gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5 => 2) +gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2) +gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2) +gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2) +gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2) +gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5 => 2) @testset "function layers" begin x = rand(Float32, 3,3) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index ca8e15a643..53b6935d0c 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -276,7 +276,7 @@ import Flux: activations @testset "Embedding" begin vocab_size, embed_size = 10, 4 - m = Flux.Embedding(vocab_size, embed_size) + m = Flux.Embedding(vocab_size => embed_size) @test size(m.weight) == (embed_size, vocab_size) x = rand(1:vocab_size, 3) From 5ff828093e1f13e176f769b04b1ff282c2fa33a5 Mon Sep 17 00:00:00 2001 From: Manikya Bardhan Date: Thu, 15 Jul 2021 10:11:20 +0530 Subject: [PATCH 4/9] updated Embedding docstring Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ca58407f23..f71c5365dc 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -508,7 +508,7 @@ julia> m([6, 15, 15]) # applied to a batch 4.01 22.01 22.01 5.01 23.01 23.01 -julia> ans == m(Flux.OneHotMatrix([6, 15, 15], 26)) +julia> ans == m(Flux.onehotbatch("foo", 'a':'z')) true ``` """ From 73d7281266ad62e05e20438ab07ef1cfc03b7f7e Mon Sep 17 00:00:00 2001 From: Manikya Date: Thu, 15 Jul 2021 10:16:11 +0530 Subject: [PATCH 5/9] updated and exported Embedding --- src/Flux.jl | 2 +- src/layers/basic.jl | 9 +++------ src/outputsize.jl | 3 +-- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 5f906a0528..a38115f1ef 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -14,7 +14,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd export gradient using ChainRulesCore -export Chain, Dense, Maxout, SkipConnection, Parallel, +export Chain, Dense, Maxout, SkipConnection, Parallel, Embedding, RNN, LSTM, GRU, GRUv3, SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv, AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f71c5365dc..7c36d5a9ae 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -521,13 +521,10 @@ end Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(last(dims), first(dims))) (m::Embedding)(x::Integer) = m.weight[:, x] -(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x) +(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x) (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...) - -function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L} - size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L")) - return m(onecold(x)) -end +(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x, 1), :)), :, size(x)[2:end]...) +(m::Embedding)(x::AbstractVecOrMat{Bool}) = m.weight * x # handles OneHotLikeVector, OneHotLikeMatrix function Base.show(io::IO, m::Embedding) print(io, "Embedding($(size(m.weight, 2)) => $(size(m.weight, 1)))") diff --git a/src/outputsize.jl b/src/outputsize.jl index 3613502671..c79d8df284 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -169,5 +169,4 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims)) end end -(m::Embedding)(x::AbstractVector{<:Nil}) = fill(nil, size(m.weight, 1), length(x)) -(m::Embedding)(x::AbstractArray{<:Nil}) = fill(nil, size(m.weight, 1), size(x)...) +(m::Embedding)(x::AbstractVecOrMat{<:Nil}) = fill(nil, size(m.weight, 1), length(x)) From 6e1e66d79c38dcfab4c77938ec699074d55ce974 Mon Sep 17 00:00:00 2001 From: Manikya Date: Thu, 15 Jul 2021 10:25:37 +0530 Subject: [PATCH 6/9] updated Embedding tests --- test/cuda/layers.jl | 4 ++-- test/layers/basic.jl | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl index 13a3830fb6..d912ec85c5 100644 --- a/test/cuda/layers.jl +++ b/test/cuda/layers.jl @@ -128,12 +128,12 @@ gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3) embedding = [Flux.Embedding] gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2) -gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5 => 2) +gpu_gradtest("Embedding repeated indices", embedding, rand(1:50, 10^6), 50 => 2) gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2) gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2) gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2) gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2) -gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5 => 2) +gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix(rand(1:50, 10^6), 50), 50 => 2) @testset "function layers" begin x = rand(Float32, 3,3) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 53b6935d0c..8ad58b26f1 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -284,9 +284,9 @@ import Flux: activations @test y isa Matrix{Float32} @test y ≈ m.weight[:,x] x2 = OneHotMatrix(x, vocab_size) - y2 = m(x2) - @test y2 isa Matrix{Float32} - @test y2 ≈ y + @test m(x2) isa Matrix{Float32} + @test m(x2) ≈ y + @test m(collect(x2)) ≈ y @test_throws DimensionMismatch m(OneHotMatrix(x, 1000)) x = rand(1:vocab_size, 3, 4) @@ -297,6 +297,9 @@ import Flux: activations @test m(2) ≈ m.weight[:,2] @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3] @test_throws DimensionMismatch m(OneHotVector(3, 1000)) + + x = onehotbatch(rand(1:vocab_size, 4, 3, 4, 5), 1:vocab_size) + @test m(x) ≈ m(onecold(x)) end end From 2d80696f6911914037756d175dc13020ce16ad56 Mon Sep 17 00:00:00 2001 From: Manikya Bardhan Date: Sun, 13 Feb 2022 21:56:44 +0530 Subject: [PATCH 7/9] add outputsize special case for NNlib.gather --- src/layers/basic.jl | 2 +- src/outputsize.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 7c36d5a9ae..cd3229ec0c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -521,7 +521,7 @@ end Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(last(dims), first(dims))) (m::Embedding)(x::Integer) = m.weight[:, x] -(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x) +(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x) (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...) (m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x, 1), :)), :, size(x)[2:end]...) (m::Embedding)(x::AbstractVecOrMat{Bool}) = m.weight * x # handles OneHotLikeVector, OneHotLikeMatrix diff --git a/src/outputsize.jl b/src/outputsize.jl index c79d8df284..54ce779d1f 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -169,4 +169,4 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims)) end end -(m::Embedding)(x::AbstractVecOrMat{<:Nil}) = fill(nil, size(m.weight, 1), length(x)) +NNlib.gather!(dst::AbstractArray, ::AbstractArray, ::AbstractArray{<:Nil}) = fill(nil, size(dst)...) \ No newline at end of file From a2f096143d77b961dae64dccb45b24808ed5ed50 Mon Sep 17 00:00:00 2001 From: Manikya Bardhan Date: Tue, 1 Mar 2022 17:20:11 +0530 Subject: [PATCH 8/9] Update src/layers/basic.jl Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/layers/basic.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index cd3229ec0c..9f47cba309 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -483,7 +483,8 @@ function Base.show(io::IO, m::Parallel) end """ - Embedding(in => out; init=randn) + Embedding(in => out; init=randn32) + Embedding(weight::AbstractMatrix) A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`. From ef13026be74addce13ebbfe1670a1d0250c0c151 Mon Sep 17 00:00:00 2001 From: Manikya Bardhan Date: Tue, 1 Mar 2022 17:42:06 +0530 Subject: [PATCH 9/9] updated tests and outputsize gather --- src/outputsize.jl | 4 +++- test/cuda/layers.jl | 4 ++-- test/outputsize.jl | 7 +++++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/outputsize.jl b/src/outputsize.jl index 54ce779d1f..02c92c7174 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -169,4 +169,6 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims)) end end -NNlib.gather!(dst::AbstractArray, ::AbstractArray, ::AbstractArray{<:Nil}) = fill(nil, size(dst)...) \ No newline at end of file +function NNlib.gather(src::AbstractArray{Tsrc, Nsrc}, idx::AbstractArray{<:Nil}) where {Tsrc, Nsrc} + fill(nil, (size(src)[1:Nsrc-1]..., size(idx)...)) +end \ No newline at end of file diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl index d912ec85c5..3c28c6a268 100644 --- a/test/cuda/layers.jl +++ b/test/cuda/layers.jl @@ -128,12 +128,12 @@ gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3) embedding = [Flux.Embedding] gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2) -gpu_gradtest("Embedding repeated indices", embedding, rand(1:50, 10^6), 50 => 2) +gpu_gradtest("Embedding repeated indices", embedding, rand(1:10, 10^3), 10 => 2) gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2) gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2) gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2) gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2) -gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix(rand(1:50, 10^6), 50), 50 => 2) +gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix(rand(1:10, 10^3), 10), 10 => 2) @testset "function layers" begin x = rand(Float32, 3,3) diff --git a/test/outputsize.jl b/test/outputsize.jl index 2c90811dcb..d1ae94223a 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -155,3 +155,10 @@ end @test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16) @test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1) end + +@testset "embedding" begin + m = Embedding(3=>5) + @test outputsize(m, (2,)) == (5, 2) + @test outputsize(m, (2, 3)) == (5, 2, 3) + @test outputsize(m, (2, 3, 4)) == (5, 2, 3, 4) +end