Skip to content

Commit 0ae379e

Browse files
manikyabardmcabbottdarsnack
authored
Apply suggestions from code review
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
1 parent 85bfcca commit 0ae379e

File tree

1 file changed

+19
-22
lines changed

1 file changed

+19
-22
lines changed

src/layers/basic.jl

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -385,28 +385,25 @@ or the corresponding [onehot encoding](@ref Flux.OneHotArray).
385385
386386
# Examples
387387
388-
```julia-repl
389-
julia> vocab_size, embed_size = 1000, 4;
390-
391-
julia> model = Embedding(vocab_size, embed_size)
392-
Embedding(1000, 4)
393-
394-
julia> vocab_idxs = [1, 722, 53, 220, 3]
388+
```jldoctest
389+
julia> m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0)
390+
Embedding(26 => 2)
395391
396-
julia> x = OneHotMatrix(vocab_idxs, vocab_size);
392+
julia> m(5) # embedding vector for 5th element
393+
2-element Vector{Float32}:
394+
2.01
395+
3.01
397396
398-
julia> model(x)
399-
4×5 Matrix{Float32}:
400-
0.91139 0.670462 0.463217 0.670462 0.110932
401-
0.247225 -0.0823874 0.698694 -0.0823874 0.945958
402-
-0.393626 -0.590136 -0.545422 -0.590136 0.77743
403-
-0.497621 0.87595 -0.870251 0.87595 -0.772696
404-
```
397+
julia> m([6, 15, 15]) # applied to a batch
398+
2×3 Matrix{Float32}:
399+
4.01 22.01 22.01
400+
5.01 23.01 23.01
405401
406-
julia> model(vocab_idxs) == model(x)
402+
julia> ans == m(Flux.OneHotMatrix([6, 15, 15], 26))
407403
true
404+
```
408405
"""
409-
struct Embedding{W}
406+
struct Embedding{W <: AbstractMatrix}
410407
weight::W
411408
end
412409

@@ -417,11 +414,11 @@ function Embedding(in::Integer, out::Integer;
417414
return Embedding(init(out, in))
418415
end
419416

420-
(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
421-
(m::Embedding)(x::Int) = m.weight[:, x]
422-
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
423-
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
417+
(m::Embedding)(x::Union{OneHotLikeVector, OneHotLikeMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
418+
(m::Embedding)(x::Integer) = m.weight[:, x]
419+
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
420+
(m::Embedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...)
424421

425422
function Base.show(io::IO, m::Embedding)
426-
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
423+
print(io, "Embedding($(size(m.weight, 2)) => $(size(m.weight, 1)))")
427424
end

0 commit comments

Comments
 (0)