Skip to content

Commit 2cbf391

Browse files
committed
don't specialise on OneHotMatrix, but do call reshape
1 parent c7ed5fe commit 2cbf391

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

src/layers/basic.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -644,14 +644,17 @@ function Base.show(io::IO, m::PairwiseFusion)
644644
end
645645

646646
"""
647-
Embedding(in => out; init=randn)
647+
Embedding(in => out; init=randn32)
648648
649649
A lookup table that stores embeddings of dimension `out`
650-
for a vocabulary of size `in`.
650+
for a vocabulary of size `in`, as a trainable matrix.
651651
652652
This layer is often used to store word embeddings and retrieve them using indices.
653-
The input to the layer can be either a vector of indexes
654-
or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
653+
The input to the layer can be a vocabulary index in `1:in`, an array of indices,
654+
or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
655+
656+
For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
657+
For one-hot `x`, the result is of size `(out, size(x)[2:end]...)`.
655658
656659
# Examples
657660
```jldoctest
@@ -684,10 +687,9 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
684687
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
685688
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
686689

687-
function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
688-
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
689-
return m(onecold(x))
690-
end
690+
(m::Embedding)(x::AbstractVector{Bool}) = m.weight * x # usually OneHotVector
691+
(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x # usually OneHotMatrix
692+
(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
691693

692694
function Base.show(io::IO, m::Embedding)
693695
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")

test/layers/basic.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,17 @@ import Flux: activations
289289

290290
@testset "Embedding" begin
291291
vocab_size, embed_size = 10, 4
292-
m = Flux.Embedding(vocab_size, embed_size)
292+
m = Embedding(vocab_size, embed_size)
293293
@test size(m.weight) == (embed_size, vocab_size)
294+
295+
# one index
296+
@test m(1) isa Vector{Float32}
297+
@test m(2) ≈ m.weight[:,2]
298+
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
299+
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
300+
@test m(4) ≈ m((1:vocab_size) .== 4)
294301

302+
# a batch of indices
295303
x = rand(1:vocab_size, 3)
296304
y = m(x)
297305
@test y isa Matrix{Float32}
@@ -301,15 +309,17 @@ import Flux: activations
301309
@test y2 isa Matrix{Float32}
302310
@test y2 ≈ y
303311
@test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
312+
@test y ≈ m(x' .== (1:vocab_size))
304313

314+
# more dimensions via reshape
305315
x = rand(1:vocab_size, 3, 4)
306316
y = m(x)
307317
@test y isa Array{Float32, 3}
308318
@test size(y) == (embed_size, 3, 4)
309-
310-
@test m(2) ≈ m.weight[:,2]
311-
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
312-
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
319+
x3 = onehotbatch(x, 1:1:vocab_size)
320+
@test size(x3) == (vocab_size, 3, 4)
321+
y3 = m(x3)
322+
@test size(y3) == (embed_size, 3, 4)
313323
end
314324
end
315325

0 commit comments

Comments
 (0)