@@ -385,28 +385,25 @@ or the corresponding [onehot encoding](@ref Flux.OneHotArray).
385
385
386
386
# Examples
387
387
388
- ```julia-repl
389
- julia> vocab_size, embed_size = 1000, 4;
390
-
391
- julia> model = Embedding(vocab_size, embed_size)
392
- Embedding(1000, 4)
393
-
394
- julia> vocab_idxs = [1, 722, 53, 220, 3]
388
+ ```jldoctest
389
+ julia> m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0)
390
+ Embedding(26 => 2)
395
391
396
- julia> x = OneHotMatrix(vocab_idxs, vocab_size);
392
+ julia> m(5) # embedding vector for 5th element
393
+ 2-element Vector{Float32}:
394
+ 2.01
395
+ 3.01
397
396
398
- julia> model(x)
399
- 4×5 Matrix{Float32}:
400
- 0.91139 0.670462 0.463217 0.670462 0.110932
401
- 0.247225 -0.0823874 0.698694 -0.0823874 0.945958
402
- -0.393626 -0.590136 -0.545422 -0.590136 0.77743
403
- -0.497621 0.87595 -0.870251 0.87595 -0.772696
404
- ```
397
+ julia> m([6, 15, 15]) # applied to a batch
398
+ 2×3 Matrix{Float32}:
399
+ 4.01 22.01 22.01
400
+ 5.01 23.01 23.01
405
401
406
- julia> model(vocab_idxs) == model(x )
402
+ julia> ans == m(Flux.OneHotMatrix([6, 15, 15], 26) )
407
403
true
404
+ ```
408
405
"""
409
- struct Embedding{W}
406
+ struct Embedding{W <: AbstractMatrix }
410
407
weight:: W
411
408
end
412
409
@@ -417,11 +414,11 @@ function Embedding(in::Integer, out::Integer;
417
414
return Embedding (init (out, in))
418
415
end
419
416
420
- (m:: Embedding )(x:: Union{OneHotVector, OneHotMatrix } ) = m. weight * x # equivalent to m.weight[:,onecold(x)]
421
- (m:: Embedding )(x:: Int ) = m. weight[:, x]
422
- (m:: Embedding )(x:: AbstractVector ) = NNlib. gather (m. weight, x)
423
- (m:: Embedding )(x:: AbstractArray ) = reshape (m (vec (x)), :, size (x)... )
417
+ (m:: Embedding )(x:: Union{OneHotLikeVector, OneHotLikeMatrix } ) = m. weight * x # equivalent to m.weight[:,onecold(x)]
418
+ (m:: Embedding )(x:: Integer ) = m. weight[:, x]
419
+ (m:: Embedding )(x:: AbstractVector{<:Integer} ) = NNlib. gather (m. weight, x)
420
+ (m:: Embedding )(x:: AbstractArray{<:Integer} ) = reshape (m (vec (x)), :, size (x)... )
424
421
425
422
function Base. show (io:: IO , m:: Embedding )
426
- print (io, " Embedding($(size (m. weight, 2 )) , $(size (m. weight, 1 )) )" )
423
+ print (io, " Embedding($(size (m. weight, 2 )) => $(size (m. weight, 1 )) )" )
427
424
end
0 commit comments