From 1f541b441a6951da36e8006256fb4cd746095654 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 14 Oct 2022 19:34:16 -0400
Subject: [PATCH 1/3] don't specialise on OneHotMatrix, but do call reshape

---
 src/layers/basic.jl  | 18 ++++++++++--------
 test/layers/basic.jl | 20 +++++++++++++++-----
 2 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 42813cb5f7..df200cd9e0 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -644,14 +644,17 @@ function Base.show(io::IO, m::PairwiseFusion)
 end
 
 """
-    Embedding(in => out; init=randn)
+    Embedding(in => out; init=randn32)
 
 A lookup table that stores embeddings of dimension `out`
-for a vocabulary of size `in`.
+for a vocabulary of size `in`, as a trainable matrix.
 
 This layer is often used to store word embeddings and retrieve them using indices.
-The input to the layer can be either a vector of indexes
-or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
+The input to the layer can be a vocabulary index in `1:in`, an array of indices,
+or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
+
+For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
+For one-hot `x`, the result is of size `(out, size(x)[2:end]...)`.
 
 # Examples
 ```jldoctest
@@ -684,10 +687,9 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
 (m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
 (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
 
-function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
-  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
-  return m(onecold(x))
-end
+(m::Embedding)(x::AbstractVector{Bool}) = m.weight * x  # usually OneHotVector
+(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x  # usually OneHotMatrix
+(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
 
 function Base.show(io::IO, m::Embedding)
   print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index d66aad4f56..1f9d30dec5 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -289,9 +289,17 @@ import Flux: activations
 
   @testset "Embedding" begin
     vocab_size, embed_size = 10, 4
-    m = Flux.Embedding(vocab_size, embed_size)
+    m = Embedding(vocab_size, embed_size)
     @test size(m.weight) == (embed_size, vocab_size)
+
+    # one index
+    @test m(1) isa Vector{Float32}
+    @test m(2) ≈ m.weight[:,2]
+    @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
+    @test_throws DimensionMismatch m(OneHotVector(3, 1000))
+    @test m(4) ≈ m((1:vocab_size) .== 4)
 
+    # a batch of indices
     x = rand(1:vocab_size, 3)
     y = m(x)
     @test y isa Matrix{Float32}
@@ -301,15 +309,17 @@
     @test y2 isa Matrix{Float32}
     @test y2 ≈ y
     @test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
+    @test y ≈ m(x' .== (1:vocab_size))
 
+    # more dimensions via reshape
     x = rand(1:vocab_size, 3, 4)
     y = m(x)
     @test y isa Array{Float32, 3}
     @test size(y) == (embed_size, 3, 4)
-
-    @test m(2) ≈ m.weight[:,2]
-    @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
-    @test_throws DimensionMismatch m(OneHotVector(3, 1000))
+    x3 = onehotbatch(x, 1:1:vocab_size)
+    @test size(x3) == (vocab_size, 3, 4)
+    y3 = m(x3)
+    @test size(y3) == (embed_size, 3, 4)
   end
 end
 
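The simplification in this patch: a one-hot vector has a single `true`, so `m.weight * x` selects one column, and the hand-written size check becomes unnecessary because `*` already throws a `DimensionMismatch`. A quick sanity-check sketch, not part of the patch, assuming Flux with this change and OneHotArrays are loaded:

```julia
using Flux, OneHotArrays

m = Flux.Embedding(10 => 4)

# Multiplying by a one-hot vector picks out one column of the weight
# matrix, matching the integer-index path that goes through NNlib.gather:
m(OneHotVector(3, 10)) ≈ m(3) ≈ m.weight[:, 3]   # true

# Any AbstractVector{Bool} now works, not only OneHotVector:
m((1:10) .== 3) ≈ m(3)                           # true

# A wrongly-sized one-hot input fails inside `*` itself:
# m(OneHotVector(3, 1000))                       # throws DimensionMismatch
```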
From 3a22bab7a7c2b0892e47fcc01642a9372e69e5e1 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 14 Oct 2022 19:52:49 -0400
Subject: [PATCH 2/3] also give a non-random example where the vectors can be
 printed. Do it without 5 named variables, and show that the point of onehot
 is variables which aren't 1:n already. Also show result of higher-rank
 input.

---
 src/layers/basic.jl | 37 ++++++++++++++++++++++---------------
 1 file changed, 22 insertions(+), 15 deletions(-)

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index df200cd9e0..f3ece1d90e 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -654,25 +654,32 @@ The input to the layer can be a vocabulary index in `1:in`, an array of indices,
 or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
 
 For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
-For one-hot `x`, the result is of size `(out, size(x)[2:end]...)`.
+For one-hot `ohx`, the result is of size `(out, size(ohx)[2:end]...)`.
 
 # Examples
 ```jldoctest
-julia> vocab_size, embed_size = 1000, 4;
-
-julia> model = Flux.Embedding(vocab_size => embed_size)
-Embedding(1000 => 4)  # 4_000 parameters
-
-julia> vocab_idxs = [1, 722, 53, 220, 3];
-
-julia> x = Flux.onehotbatch(vocab_idxs, 1:vocab_size); summary(x)
-"1000×5 OneHotMatrix(::Vector{UInt32}) with eltype Bool"
-
-julia> model(x) |> summary
-"4×5 Matrix{Float32}"
-
-julia> model(vocab_idxs) == model(x)
+julia> emb = Embedding(26 => 4, init=Flux.identity_init(gain=22))
+Embedding(26 => 4)  # 104 parameters
+
+julia> emb(2)  # one column of emb.weight (here not random!)
+4-element Vector{Float32}:
+  0.0
+ 22.0
+  0.0
+  0.0
+
+julia> emb([3, 1, 20, 14, 4, 15, 7])  # vocabulary indices, in 1:26
+4×7 Matrix{Float32}:
+  0.0  22.0  0.0  0.0   0.0  0.0  0.0
+  0.0   0.0  0.0  0.0   0.0  0.0  0.0
+ 22.0   0.0  0.0  0.0   0.0  0.0  0.0
+  0.0   0.0  0.0  0.0  22.0  0.0  0.0
+
+julia> ans == emb(Flux.onehotbatch("cat&dog", 'a':'z', 'n'))
 true
+
+julia> emb(rand(1:26, (10, 1, 12))) |> size  # three batch dimensions
+(4, 10, 1, 12)
 ```
 """
 struct Embedding{W}
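The rewritten doctest leans on two helpers worth spelling out: `Flux.identity_init` makes the weight deterministic so the printed output is stable, and the three-argument `onehotbatch` maps out-of-vocabulary characters to a default label. A small sketch of both, not part of the patch; the exact return types are an assumption worth checking against the installed OneHotArrays:

```julia
using Flux

# identity_init puts `gain` on the main diagonal and zeros elsewhere,
# so each embedding column is predictable, unlike the randn32 default:
w = Flux.identity_init(4, 26; gain=22)
w[2, 2] == 22.0f0          # true

# onehotbatch with a third argument sends anything outside 'a':'z'
# (here the '&') to the default label 'n':
oh = Flux.onehotbatch("cat&dog", 'a':'z', 'n')
size(oh)                   # (26, 7)
Flux.onecold(oh, 'a':'z')  # ['c', 'a', 't', 'n', 'd', 'o', 'g']
```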
From 4676a82376872eef3ef9e86dee2a20cc4dfc60c4 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 17 Oct 2022 09:31:24 -0400
Subject: [PATCH 3/3] restrict type of field to AbstractMatrix

---
 src/layers/basic.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index f3ece1d90e..2a3bc9131c 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -682,7 +682,7 @@ julia> emb(rand(1:26, (10, 1, 12))) |> size  # three batch dimensions
 (4, 10, 1, 12)
 ```
 """
-struct Embedding{W}
+struct Embedding{W<:AbstractMatrix}
   weight::W
 end
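What the tighter type parameter buys: a weight that is not a matrix is now rejected at construction time, rather than failing much later inside `gather` or `*`, while any `AbstractMatrix` type (GPU arrays included) still works. A minimal illustration, not part of the patch:

```julia
using Flux

# The default struct constructor accepts any matrix type as the weight:
Flux.Embedding(rand(Float32, 4, 10))  # Embedding(10 => 4)

# A vector weight is now a MethodError at construction time:
# Flux.Embedding(rand(Float32, 4))
```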