
Commit 0031340

cl/embed
1 parent 9410677 commit 0031340


4 files changed: +80 -3 lines changed


docs/src/models/layers.md

Lines changed: 1 addition & 0 deletions
````diff
@@ -58,6 +58,7 @@ SkipConnection
 Parallel
 Flux.Bilinear
 Flux.Diagonal
+Flux.Embedding
 ```

 ## Normalisation & Regularisation
````

src/Flux.jl

Lines changed: 1 addition & 1 deletion
````diff
@@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient

 export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
-       RNN, LSTM, GRU,
+       RNN, LSTM, GRU, Embedding,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
        Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
````
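With `Embedding` added to the export list, the layer is reachable without the `Flux.` prefix once the package is loaded. A minimal sketch, assuming a Flux build that includes this commit:

```julia
using Flux

# The new export makes the unqualified name available;
# previously this would have required writing Flux.Embedding.
emb = Embedding(1000, 4)   # a lookup table backed by a 4×1000 weight matrix
```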

src/layers/basic.jl

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ on a given input.
88
`m[1:3](x)` will calculate the output of the first three layers.
99
1010
# Examples
11+
1112
```jldoctest
1213
julia> m = Chain(x -> x^2, x -> x+1);
1314
@@ -388,7 +389,8 @@ If called with multiple inputs, they are `zip`ped with the layers, thus `Paralle
388389
389390
```jldoctest
390391
julia> model = Chain(Dense(3, 5),
391-
Parallel(vcat, Dense(5, 4), Chain(Dense(5, 7), Dense(7, 4))),
392+
Parallel(vcat, De
393+
print(io, ")")nse(5, 4), Chain(Dense(5, 7), Dense(7, 4))),
392394
Dense(8, 17));
393395
394396
julia> size(model(rand(3)))
@@ -421,4 +423,57 @@ function Base.show(io::IO, m::Parallel)
421423
print(io, "Parallel(", m.connection, ", ")
422424
join(io, m.layers, ", ")
423425
print(io, ")")
424-
end
426+
end
427+
428+
"""
429+
Embedding(in, out; init=randn)
430+
431+
A lookup table that stores embeddings of dimension `out`
432+
for a vocabulary of size `in`.
433+
434+
This layers is often used to store word embeddings and retrieve them using indices.
435+
The input to the layer can be either a vector of indexes
436+
or the corresponding onehot encoding.
437+
438+
# Examples
439+
440+
```julia-repl
441+
julia> vocab_size, embed_size = 1000, 4;
442+
443+
julia> model = Embedding(vocab_size, embed_size)
444+
Embedding(1000, 4)
445+
446+
julia> vocab_idxs = [1, 722, 53, 220, 3]
447+
448+
julia> x = OneHotMatrix(vocab_idxs, vocab_size);
449+
450+
julia> model(x)
451+
4×5 Matrix{Float32}:
452+
0.91139 0.670462 0.463217 0.670462 0.110932
453+
0.247225 -0.0823874 0.698694 -0.0823874 0.945958
454+
-0.393626 -0.590136 -0.545422 -0.590136 0.77743
455+
-0.497621 0.87595 -0.870251 0.87595 -0.772696
456+
```
457+
458+
julia> model(vocab_idxs) # same as above
459+
"""
460+
struct Embedding{W}
461+
weight::W
462+
end
463+
464+
@functor Embedding
465+
466+
function Embedding(in::Integer, out::Integer;
467+
init = (i...) -> randn(Float32, i...))
468+
return Embedding(init(out, in))
469+
end
470+
471+
(m::Embedding)(x::OneHotMatrix) = m.weight * x # equivalent to m.weight[:, onecold(x)]
472+
(m::Embedding)(x::OneHotVector) = m.weight * x
473+
(m::Embedding)(x::AbstractVector) = m.weight[:, x]
474+
(m::Embedding)(x::Int) = m.weight[:, x]
475+
476+
function Base.show(io::IO, m::Embedding)
477+
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
478+
end
479+
>>>>>>> b22cd2dc (cl/embed)
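The two call paths of the new layer are interchangeable: integer indices slice columns out of the weight matrix directly, while one-hot inputs reach the same columns through a matrix multiply, as the inline comment on the `OneHotMatrix` method notes. A short sketch of that equivalence, assuming a Flux build that includes this commit (`OneHotMatrix` is the same constructor the commit's tests use):

```julia
using Flux
using Flux: OneHotMatrix

m = Embedding(10, 4)                # weight is a 4×10 Matrix{Float32}
idxs = [3, 7, 7, 1]

y_idx = m(idxs)                     # index path: m.weight[:, idxs]
y_hot = m(OneHotMatrix(idxs, 10))   # one-hot path: m.weight * x

y_idx ≈ y_hot                       # true: both select the same columns
```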

test/layers/basic.jl

Lines changed: 21 additions & 0 deletions
````diff
@@ -191,4 +191,25 @@ import Flux: activations
       @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
     end
   end
+
+  @testset "Embedding" begin
+    vocab_size, embed_size = 10, 4
+    m = Embedding(vocab_size, embed_size)
+    @test size(m.weight) == (embed_size, vocab_size)
+
+    x = rand(1:vocab_size, 3)
+    y = m(x)
+    @test y isa Matrix{Float32}
+    @test y ≈ m.weight[:,x]
+
+    x2 = OneHotMatrix(x, vocab_size)
+    y2 = m(x2)
+    @test y2 isa Matrix{Float32}
+    @test y2 ≈ y
+    @test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
+
+    @test m(2) ≈ m.weight[:,2]
+    @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
+    @test_throws DimensionMismatch m(OneHotVector(3, 1000))
+  end
 end
````
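Because the struct is registered with `@functor`, the lookup table is a trainable parameter like any other layer weight. A sketch of gradient flow into the table, again assuming this commit is applied; the surrounding model here is illustrative only:

```julia
using Flux

# Illustrative model: embed token ids, mean-pool over positions, project to logits.
model = Chain(Embedding(10, 4),
              x -> vec(sum(x, dims=2) ./ size(x, 2)),
              Dense(4, 2))

tokens = [3, 7, 1]
ps = Flux.params(model)    # includes the Embedding weight via @functor
gs = gradient(() -> sum(abs2, model(tokens)), ps)

size(gs[model[1].weight])  # (4, 10): the gradient array matches the table's size
```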
