Skip to content

Commit 1e2cffd

Browse files
manikyabardmcabbottdarsnack
committed
Apply suggestions from code review
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
1 parent 70d90cf commit 1e2cffd

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

src/layers/basic.jl

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -487,30 +487,25 @@ or the corresponding [onehot encoding](@ref Flux.OneHotArray).
487487
488488
# Examples
489489
490-
```julia-repl
491-
julia> using Flux: Embedding
492-
493-
julia> vocab_size, embed_size = 1000, 4;
494-
495-
julia> model = Embedding(vocab_size, embed_size)
496-
Embedding(1000, 4)
497-
498-
julia> vocab_idxs = [1, 722, 53, 220, 3]
490+
```jldoctest
491+
julia> m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0)
492+
Embedding(26 => 2)
499493
500-
julia> x = OneHotMatrix(vocab_idxs, vocab_size);
494+
julia> m(5) # embedding vector for 5th element
495+
2-element Vector{Float32}:
496+
2.01
497+
3.01
501498
502-
julia> model(x)
503-
4×5 Matrix{Float32}:
504-
0.91139 0.670462 0.463217 0.670462 0.110932
505-
0.247225 -0.0823874 0.698694 -0.0823874 0.945958
506-
-0.393626 -0.590136 -0.545422 -0.590136 0.77743
507-
-0.497621 0.87595 -0.870251 0.87595 -0.772696
508-
```
499+
julia> m([6, 15, 15]) # applied to a batch
500+
2×3 Matrix{Float32}:
501+
4.01 22.01 22.01
502+
5.01 23.01 23.01
509503
510-
julia> model(vocab_idxs) == model(x)
504+
julia> ans == m(Flux.OneHotMatrix([6, 15, 15], 26))
511505
true
506+
```
512507
"""
513-
struct Embedding{W}
508+
struct Embedding{W <: AbstractMatrix}
514509
weight::W
515510
end
516511

@@ -529,5 +524,5 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
529524
end
530525

531526
function Base.show(io::IO, m::Embedding)
532-
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
527+
print(io, "Embedding($(size(m.weight, 2)) => $(size(m.weight, 1)))")
533528
end

0 commit comments

Comments
 (0)