Skip to content

Commit 3f57415

Browse files
committed
make outputsize work with Embedding
1 parent 090f043 commit 3f57415

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

src/layers/basic.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,9 @@ julia> model(x) |> summary
670670
671671
julia> model(vocab_idxs) == model(x)
672672
true
673+
674+
julia> Flux.outputsize(model, size(vocab_idxs)) # outputsize wants indices, not OneHotArray
675+
(4, 5)
673676
```
674677
"""
675678
struct Embedding{W}
@@ -681,8 +684,11 @@ end
681684
Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
682685

683686
(m::Embedding)(x::Integer) = m.weight[:, x]
684-
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
685-
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
687+
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
688+
(m::Embedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...)
689+
690+
(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1))
691+
(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...)
686692

687693
function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
688694
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))

src/outputsize.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ DimensionMismatch("Input channels must match! (7 vs. 3)")
8787
julia> outputsize([Dense(10 => 4), Dense(4 => 2)], (10, 1)) # Vector of layers becomes a Chain
8888
(2, 1)
8989
```
90+
91+
Limitations:
92+
* `Embedding` accepts either integers or one-hot arrays, and `ohx = onehotbatch(x, ...)`
93+
has one more dimension than `x`. Here `outputsize` uses `size(x)`.
94+
* At present `outputsize` does not work with recurrent layers,
95+
`outputsize(RNN(2 => 3), (2, 1))` gives an error. This is a bug.
9096
"""
9197
function outputsize(m, inputsizes::Tuple...; padbatch=false)
9298
x = nil_input(padbatch, inputsizes...)

test/outputsize.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
m = Dense(10, 5)
66
@test_throws DimensionMismatch outputsize(m, (5, 2)) == (5, 1)
77
@test outputsize(m, (10,); padbatch=true) == (5, 1)
8+
@test outputsize(m, (10,)) == (5,)
9+
@test outputsize(m, (10, 6, 7)) == (5, 6, 7)
810

911
m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2))
1012
@test outputsize(m, (10,); padbatch=true) == (2, 1)
@@ -41,6 +43,19 @@
4143
@test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1)
4244
end
4345

46+
@testset "embeddings" begin
47+
# Here outputsize expects indices, not one-hot representation:
48+
m = Embedding(3 => 4)
49+
@test outputsize(m, (3, 7)) == (4, 3, 7) == size(m(rand(1:3, 3, 7)))
50+
@test outputsize(m, (5, 6, 7)) == (4, 5, 6, 7) == size(m(rand(1:3, 5, 6, 7)))
51+
52+
m = Chain(x -> Flux.onehotbatch(x, 1:5), Embedding(5 => 7))
53+
@test size(m([3,4])) == (7, 2)
54+
@test outputsize(m, (2,)) == (7, 2)
55+
# This works because Flux.onehotbatch([nil, nil], 1:5) makes a 5×2 OneHotMatrix
56+
# But e.g. Flux.onehotbatch([nil, nil], 'a':'e') will not work.
57+
end
58+
4459
@testset "multiple inputs" begin
4560
m = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu))
4661
@test outputsize(m, (2,), (3,)) == (10,)

0 commit comments

Comments
 (0)