@@ -487,30 +487,25 @@ or the corresponding [onehot encoding](@ref Flux.OneHotArray).
 
 # Examples
 
-```julia-repl
-julia> using Flux: Embedding
-
-julia> vocab_size, embed_size = 1000, 4;
-
-julia> model = Embedding(vocab_size, embed_size)
-Embedding(1000, 4)
-
-julia> vocab_idxs = [1, 722, 53, 220, 3]
+```jldoctest
+julia> m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0)
+Embedding(26 => 2)
 
-julia> x = OneHotMatrix(vocab_idxs, vocab_size);
+julia> m(5)  # embedding vector for 5th element
+2-element Vector{Float32}:
+ 2.01
+ 3.01
 
-julia> model(x)
-4×5 Matrix{Float32}:
-  0.91139    0.670462    0.463217   0.670462    0.110932
-  0.247225  -0.0823874   0.698694  -0.0823874   0.945958
- -0.393626  -0.590136   -0.545422  -0.590136    0.77743
- -0.497621   0.87595    -0.870251   0.87595    -0.772696
-```
+julia> m([6, 15, 15])  # applied to a batch
+2×3 Matrix{Float32}:
+ 4.01  22.01  22.01
+ 5.01  23.01  23.01
 
-julia> model(vocab_idxs) == model(x)
+julia> ans == m(Flux.OneHotMatrix([6, 15, 15], 26))
 true
+```
 """
-struct Embedding{W}
+struct Embedding{W <: AbstractMatrix}
   weight::W
 end
 
@@ -529,5 +524,5 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
 end
 
 function Base.show(io::IO, m::Embedding)
-  print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
+  print(io, "Embedding($(size(m.weight, 2)) => $(size(m.weight, 1)))")
 end
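
Taken together, the diff rewrites the docstring example as a `jldoctest`, constrains the type parameter to `W <: AbstractMatrix`, and switches `show` to the `in => out` display. As a quick illustration of the behaviour being documented, here is a minimal, self-contained sketch; `SketchEmbedding` and its methods are illustrative assumptions, not Flux's actual implementation, which also accepts one-hot inputs and registers `weight` as a trainable parameter:

```julia
# Minimal sketch of the embedding-lookup behaviour shown in the doctest above.
# `SketchEmbedding` is a hypothetical stand-in, not Flux's Embedding.
struct SketchEmbedding{W<:AbstractMatrix}
    weight::W
end

# An integer input selects one column of the weight matrix; a vector of
# integers selects one column per index, giving a matrix "batch".
(m::SketchEmbedding)(x::Integer) = m.weight[:, x]
(m::SketchEmbedding)(x::AbstractVector{<:Integer}) = m.weight[:, x]

# Mirror the `in => out` display introduced in this commit.
Base.show(io::IO, m::SketchEmbedding) =
    print(io, "SketchEmbedding($(size(m.weight, 2)) => $(size(m.weight, 1)))")

m = SketchEmbedding(reshape(-6:45, 2, 26) .+ 0.01f0)  # same weights as the doctest
m(5)            # Float32[2.01, 3.01]
m([6, 15, 15])  # 2×3 Matrix{Float32} built from columns 6, 15, 15
```

Indexing columns by integer and multiplying by the corresponding one-hot matrix pick out the same columns, which is what the new `ans == m(Flux.OneHotMatrix([6, 15, 15], 26))` doctest line asserts.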