Commit 7175c36

add embedding layer
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
1 parent 46b73a8 · commit 7175c36

5 files changed: 95 additions, 1 deletion

docs/src/models/layers.md

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ SkipConnection
 Parallel
 Flux.Bilinear
 Flux.Diagonal
+Flux.Embedding
 ```

 ## Normalisation & Regularisation

src/Flux.jl

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient

 export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
-       RNN, LSTM, GRU,
+       RNN, LSTM, GRU, Embedding,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
        Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,

src/layers/basic.jl

Lines changed: 53 additions & 0 deletions
@@ -8,6 +8,7 @@ on a given input.
 `m[1:3](x)` will calculate the output of the first three layers.

 # Examples
+
 ```jldoctest
 julia> m = Chain(x -> x^2, x -> x+1);

@@ -428,3 +429,55 @@ function Base.show(io::IO, m::Parallel)
   join(io, m.layers, ", ")
   print(io, ")")
 end
+
+"""
+    Embedding(in, out; init=randn)
+
+A lookup table that stores embeddings of dimension `out`
+for a vocabulary of size `in`.
+
+This layer is often used to store word embeddings and retrieve them using indices.
+The input to the layer can be either a vector of indices
+or the corresponding [onehot encoding](@ref Flux.OneHotArray).
+
+# Examples
+
+```julia-repl
+julia> vocab_size, embed_size = 1000, 4;
+
+julia> model = Embedding(vocab_size, embed_size)
+Embedding(1000, 4)
+
+julia> vocab_idxs = [1, 722, 53, 220, 3];
+
+julia> x = OneHotMatrix(vocab_idxs, vocab_size);
+
+julia> model(x)
+4×5 Matrix{Float32}:
+  0.91139    0.670462   0.463217   0.670462   0.110932
+  0.247225  -0.0823874  0.698694  -0.0823874  0.945958
+ -0.393626  -0.590136  -0.545422  -0.590136   0.77743
+ -0.497621   0.87595   -0.870251   0.87595   -0.772696
+
+julia> model(vocab_idxs) == model(x)
+true
+```
+"""
+struct Embedding{W}
+  weight::W
+end
+
+@functor Embedding
+
+function Embedding(in::Integer, out::Integer;
+                   init = (i...) -> randn(Float32, i...))
+  return Embedding(init(out, in))
+end
+
+(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x  # equivalent to m.weight[:, onecold(x)]
+(m::Embedding)(x::Union{Int, AbstractVector}) = m.weight[:, x]
+(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
+
+function Base.show(io::IO, m::Embedding)
+  print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
+end
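A note on the forward pass above: it dispatches on input type, so one-hot inputs take a matrix-multiply path while integer inputs index columns of `weight` directly, and the two agree. A minimal sketch of that equivalence (the sizes and indices here are arbitrary; `OneHotMatrix` is the same constructor the docstring example uses):

```julia
using Flux
using Flux: OneHotMatrix

vocab_size, embed_size = 10, 4
m = Embedding(vocab_size, embed_size)

idxs = [1, 7, 3]
onehots = OneHotMatrix(idxs, vocab_size)  # 10×3 one-hot encoding of idxs

m(onehots) ≈ m.weight * onehots  # one-hot path: a matrix multiply
m(idxs) == m.weight[:, idxs]     # integer path: direct column indexing
m(onehots) ≈ m(idxs)             # both select the same columns of weight
```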

test/cuda/layers.jl

Lines changed: 15 additions & 0 deletions
@@ -259,3 +259,18 @@ end
     end
   end
 end
+
+@testset "Embedding" begin
+  vocab_size, embed_size = 10, 4
+  m = Embedding(vocab_size, embed_size)
+  x = rand(1:vocab_size, 3)
+  y = m(x)
+  m_g = m |> gpu
+  x_g = x |> gpu
+  y_g = m_g(x_g)
+  @test collect(y_g) == y
+  gs = gradient(() -> sum(tanh.(m(x))), params(m))
+  gs_g = gradient(() -> sum(tanh.(m_g(x_g))), params(m_g))
+  @test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
+end
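For readers of this test: the gradient with respect to the embedding table is an ordinary dense array of the same shape as `weight`, with nonzero entries only in the columns that were actually looked up. A small CPU-only sketch under that assumption (sizes and indices made up):

```julia
using Flux

m = Embedding(10, 4)
x = [3, 7]  # look up columns 3 and 7

gs = gradient(() -> sum(m(x)), params(m))
g = gs[m.weight]  # 4×10 dense array, same shape as m.weight

g[:, 3]  # ones: each entry of column 3 contributes once to the sum
g[:, 1]  # zeros: index 1 was never looked up
```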

test/layers/basic.jl

Lines changed: 25 additions & 0 deletions
@@ -191,4 +191,29 @@ import Flux: activations
     @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
   end
 end
+
+@testset "Embedding" begin
+  vocab_size, embed_size = 10, 4
+  m = Embedding(vocab_size, embed_size)
+  @test size(m.weight) == (embed_size, vocab_size)
+
+  x = rand(1:vocab_size, 3)
+  y = m(x)
+  @test y isa Matrix{Float32}
+  @test y ≈ m.weight[:, x]
+  x2 = OneHotMatrix(x, vocab_size)
+  y2 = m(x2)
+  @test y2 isa Matrix{Float32}
+  @test y2 ≈ y
+  @test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
+
+  x = rand(1:vocab_size, 3, 4)
+  y = m(x)
+  @test y isa Array{Float32, 3}
+  @test size(y) == (embed_size, 3, 4)
+
+  @test m(2) ≈ m.weight[:, 2]
+  @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:, 3]
+  @test_throws DimensionMismatch m(OneHotVector(3, 1000))
+end
 end
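Finally, a hypothetical end-to-end sketch of how the new layer composes with the rest of Flux, mirroring the shapes exercised in these tests: a vector of n token indices maps to an `embed_size`×n matrix, which downstream layers treat as a batch of feature columns (all sizes below are made up):

```julia
using Flux

vocab_size, embed_size, nclasses = 1000, 4, 2

# Per-token classifier: embed each token, then score it.
model = Chain(
  Embedding(vocab_size, embed_size),  # indices -> embed_size×n columns
  Dense(embed_size, nclasses),        # one score vector per token
  softmax,                            # column-wise probabilities
)

tokens = rand(1:vocab_size, 5)  # a 5-token input sequence
model(tokens)                   # 2×5 matrix of class probabilities
```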
