
Commit f0c603c

Hotfix for ViT on GPU
1 parent ee344a9

File tree

1 file changed (+2 −2)


src/layers/attention.jl

Lines changed: 2 additions & 2 deletions
@@ -50,9 +50,9 @@ function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
     scale = convert(T, sqrt(size(query, 1) / m.nheads))
     key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads,
                            seq_len * batch_size)
-    query_reshaped = reshape(query, nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
+    query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
     attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
-    value_reshaped = reshape(value, nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
+    value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
     pre_projection = reshape(batched_mul(attention, value_reshaped),
                              (nfeatures, seq_len, batch_size))
     y = m.projection(reshape(pre_projection, size(pre_projection, 1), :))
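
Both changed lines wrap the array in permutedims with the identity permutation (1, 2, 3, 4) before reshaping. permutedims always materializes a fresh, contiguous copy, so the subsequent reshape yields a plain dense array rather than a lazy wrapper around a view; a plausible reading of this hotfix is that batched_mul's GPU path only works on such dense arrays. A minimal sketch of the trick, in Base Julia only — the sizes below are invented for illustration and are not part of the commit:

# Hypothetical sizes, chosen only for illustration.
nheads, seq_len, batch_size = 4, 16, 2
nfeatures = 32
query = rand(Float32, nfeatures ÷ nheads, nheads, seq_len, batch_size)

# Before the fix: reshape acts on the array directly. If `query` is a
# non-contiguous view (e.g. a slice of a fused QKV projection), this can
# produce a lazy ReshapedArray that GPU batched_mul does not accept.
q_old = reshape(query, nfeatures ÷ nheads, nheads, seq_len * batch_size)

# After the fix: the identity permutedims copies the data into a fresh
# contiguous array first, so the reshape is a plain dense array.
q_new = reshape(permutedims(query, (1, 2, 3, 4)),
                nfeatures ÷ nheads, nheads, seq_len * batch_size)

@assert q_old == q_new  # same values; only the memory-layout guarantee differs

The cost is one extra copy of query and value per forward pass, which fits the "hotfix" framing: it restores GPU correctness without restructuring the surrounding reshapes.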
