Skip to content

Commit b22cd2d

Browse files
cl/embed
1 parent ea41ea6 commit b22cd2d

File tree

4 files changed

+78
-2
lines changed

4 files changed

+78
-2
lines changed

docs/src/models/layers.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Maxout
5858
SkipConnection
5959
Parallel
6060
Bilinear
61+
Embedding
6162
```
6263

6364
## Normalisation & Regularisation

src/Flux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
1212
export gradient
1313

1414
export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
15-
RNN, LSTM, GRU,
15+
RNN, LSTM, GRU, Embedding,
1616
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
1717
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
1818
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,

src/layers/basic.jl

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ on a given input.
88
`m[1:3](x)` will calculate the output of the first three layers.
99
1010
# Examples
11+
1112
```jldoctest
1213
julia> m = Chain(x -> x^2, x -> x+1);
1314
@@ -337,7 +338,8 @@ If called with multiple inputs, they are `zip`ped with the layers, thus `Paralle
337338
338339
```jldoctest
339340
julia> model = Chain(Dense(3, 5),
340-
Parallel(vcat, Dense(5, 4), Chain(Dense(5, 7), Dense(7, 4))),
341+
Parallel(vcat, Dense(5, 4),
342+
Chain(Dense(5, 7), Dense(7, 4))),
341343
Dense(8, 17));
342344
343345
julia> size(model(rand(3)))
@@ -371,3 +373,55 @@ function Base.show(io::IO, m::Parallel)
371373
join(io, m.layers, ", ")
372374
print(io, ")")
373375
end
376+
377+
"""
378+
Embedding(in, out; init=randn)
379+
380+
A lookup table that stores embeddings of dimension `out`
381+
for a vocabulary of size `in`.
382+
383+
This layers is often used to store word embeddings and retrieve them using indices.
384+
The input to the layer can be either a vector of indexes
385+
or the corresponding onehot encoding.
386+
387+
# Examples
388+
389+
```julia-repl
390+
julia> vocab_size, embed_size = 1000, 4;
391+
392+
julia> model = Embedding(vocab_size, embed_size)
393+
Embedding(1000, 4)
394+
395+
julia> vocab_idxs = [1, 722, 53, 220, 3]
396+
397+
julia> x = OneHotMatrix(vocab_idxs, vocab_size);
398+
399+
julia> model(x)
400+
4×5 Matrix{Float32}:
401+
0.91139 0.670462 0.463217 0.670462 0.110932
402+
0.247225 -0.0823874 0.698694 -0.0823874 0.945958
403+
-0.393626 -0.590136 -0.545422 -0.590136 0.77743
404+
-0.497621 0.87595 -0.870251 0.87595 -0.772696
405+
```
406+
407+
julia> model(vocab_idxs) # same as above
408+
"""
409+
struct Embedding{W}
410+
weight::W
411+
end
412+
413+
@functor Embedding
414+
415+
function Embedding(in::Integer, out::Integer;
416+
init = (i...) -> randn(Float32, i...))
417+
return Embedding(init(out, in))
418+
end
419+
420+
(m::Embedding)(x::OneHotMatrix) = m.weight * x # equivalent to m.weight[:, onecold(x)]
421+
(m::Embedding)(x::OneHotVector) = m.weight * x
422+
(m::Embedding)(x::AbstractVector) = m.weight[:, x]
423+
(m::Embedding)(x::Int) = m.weight[:, x]
424+
425+
function Base.show(io::IO, m::Embedding)
426+
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
427+
end

test/layers/basic.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,25 @@ import Flux: activations
153153
@test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
154154
end
155155
end
156+
157+
@testset "Embedding" begin
158+
vocab_size, embed_size = 10, 4
159+
m = Embedding(vocab_size, embed_size)
160+
@test size(m.weight) == (embed_size, vocab_size)
161+
162+
x = rand(1:vocab_size, 3)
163+
y = m(x)
164+
@test y isa Matrix{Float32}
165+
@test y m.weight[:,x]
166+
167+
x2 = OneHotMatrix(x, vocab_size)
168+
y2 = m(x2)
169+
@test y2 isa Matrix{Float32}
170+
@test y2 y
171+
@test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
172+
173+
@test m(2) m.weight[:,2]
174+
@test m(OneHotVector(3, vocab_size)) m.weight[:,3]
175+
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
176+
end
156177
end

0 commit comments

Comments
 (0)