1 parent ee344a9 commit f0c603c
src/layers/attention.jl
@@ -50,9 +50,9 @@ function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
     scale = convert(T, sqrt(size(query, 1) / m.nheads))
     key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads,
                            seq_len * batch_size)
-    query_reshaped = reshape(query, nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
+    query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
     attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
-    value_reshaped = reshape(value, nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
+    value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
     pre_projection = reshape(batched_mul(attention, value_reshaped),
                              (nfeatures, seq_len, batch_size))
     y = m.projection(reshape(pre_projection, size(pre_projection, 1), :))
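For readers following the change: the two rewritten lines replace a plain reshape of query/value with reshape(permutedims(..., (1, 2, 3, 4)), ...). The identity permutation does not reorder axes; permutedims always returns a fresh contiguous array, which is then reshaped to 3-D. Below is a minimal, self-contained sketch of the shapes involved, not the repository's code: the sizes (nfeatures = 8, nheads = 2, seq_len = 4, batch_size = 3) and the 4-D (head_dim, nheads, seq_len, batch_size) layout of query and key are assumptions made only for illustration.

    using NNlib: batched_mul, softmax

    # Hypothetical sizes chosen only for this illustration.
    nfeatures, nheads, seq_len, batch_size = 8, 2, 4, 3
    head_dim = nfeatures ÷ nheads                                      # 4

    # Dummy query/key in the assumed 4-D layout:
    # (head_dim, nheads, seq_len, batch_size)
    query = rand(Float32, head_dim, nheads, seq_len, batch_size)
    key   = rand(Float32, head_dim, nheads, seq_len, batch_size)

    scale = convert(Float32, sqrt(size(query, 1) / nheads))

    # permutedims with the identity permutation keeps the axis order but
    # materialises a new contiguous array before the 3-D reshape.
    query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)),
                             head_dim, nheads, seq_len * batch_size)   # (4, 2, 12)

    # The key swaps its first two axes, as in the hunk above, before flattening.
    key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)),
                           nheads, head_dim, seq_len * batch_size)     # (2, 4, 12)

    # batched_mul contracts the first two dimensions slice-by-slice over the third.
    attention = softmax(batched_mul(query_reshaped, key_reshaped) .* scale)
    @show size(attention)                                              # (4, 4, 12)

Under these assumed sizes the script prints size(attention) = (4, 4, 12); the real dimensions depend on the model configuration, and the dropout and output projection from the hunk are omitted here.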