Simplify Embedding
#2084
Changes from 2 commits
@@ -644,32 +644,42 @@ function Base.show(io::IO, m::PairwiseFusion)
 end
 
 """
-    Embedding(in => out; init=randn)
+    Embedding(in => out; init=randn32)
 
 A lookup table that stores embeddings of dimension `out`
-for a vocabulary of size `in`.
+for a vocabulary of size `in`, as a trainable matrix.
 
 This layer is often used to store word embeddings and retrieve them using indices.
-The input to the layer can be either a vector of indexes
-or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
+The input to the layer can be a vocabulary index in `1:in`, an array of indices,
+or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
+
+For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
+For one-hot `ohx`, the result is of size `(out, size(ohx)[2:end]...)`.
 
 # Examples
 ```jldoctest
-julia> vocab_size, embed_size = 1000, 4;
-
-julia> model = Flux.Embedding(vocab_size => embed_size)
-Embedding(1000 => 4)  # 4_000 parameters
-
-julia> vocab_idxs = [1, 722, 53, 220, 3];
-
-julia> x = Flux.onehotbatch(vocab_idxs, 1:vocab_size); summary(x)
-"1000×5 OneHotMatrix(::Vector{UInt32}) with eltype Bool"
-
-julia> model(x) |> summary
-"4×5 Matrix{Float32}"
-
-julia> model(vocab_idxs) == model(x)
+julia> emb = Embedding(26 => 4, init=Flux.identity_init(gain=22))
+Embedding(26 => 4)  # 104 parameters
+
+julia> emb(2)  # one column of e.weight (here not random!)
+4-element Vector{Float32}:
+  0.0
+ 22.0
+  0.0
+  0.0
+
+julia> emb([3, 1, 20, 14, 4, 15, 7])  # vocabulary indices, in 1:26
+4×7 Matrix{Float32}:
+  0.0  22.0  0.0  0.0   0.0  0.0  0.0
+  0.0   0.0  0.0  0.0   0.0  0.0  0.0
+ 22.0   0.0  0.0  0.0   0.0  0.0  0.0
+  0.0   0.0  0.0  0.0  22.0  0.0  0.0
+
+julia> ans == emb(Flux.onehotbatch("cat&dog", 'a':'z', 'n'))
 true
+
+julia> emb(rand(1:26, (10, 1, 12))) |> size  # three batch dimensions
+(4, 10, 1, 12)
 ```
 """
 struct Embedding{W}
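One thing the rewritten docstring stresses is that the embedding table is stored "as a trainable matrix". As a quick illustration of what that means in practice, here is a minimal sketch (not part of the diff; the sizes are arbitrary, and it assumes a Flux version with explicit-mode Zygote gradients):

```julia
using Flux

emb = Flux.Embedding(10 => 3)  # toy sizes: vocabulary of 10, dimension 3

# Gradients flow back to emb.weight. Since the adjoint of NNlib.gather
# is a scatter, only the looked-up columns receive nonzero gradient.
grads = Flux.gradient(emb) do m
    sum(abs2, m([1, 4, 4]))
end

grads[1].weight  # 3×10 array, nonzero only in columns 1 and 4
```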
@@ -684,10 +694,9 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
 (m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
 (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
 
-function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
-  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
-  return m(onecold(x))
-end
+(m::Embedding)(x::AbstractVector{Bool}) = m.weight * x  # usually OneHotVector
+(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x  # usually OneHotMatrix
Comment on lines +697 to +698

- These could instead call …

- For performance in the one-hot case? If it's …

- For …: for a one-hot BitArray, the results will agree. I would guess that `onecold` is faster, but I haven't checked. For a generic BitArray, I'm not sure which is mathematically expected, really. I think you're saying that …

- Yes, what you wrote is what I meant re: performance. I was adding that in the one-hot BitArray case, we can direct people to … Yeah, whenever I've come across this type of operation in papers, I see it written as …

- Yes. Mixing two embedding vectors seems less wrong. But probably nobody ever hits this, & it's just a way to decouple from the OneHotArray types. I don't think we should document that boolean indexing is an option.

- So I think we are happy with the current implementation in the PR?

- Yes, I think so. I see we had a very similar discussion in #1656 (comment) BTW, which I'd forgotten... but same conclusion.
+(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
 
 function Base.show(io::IO, m::Embedding)
   print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")