Skip to content

Commit 5bd0bf3

Browse files
authored
Merge pull request #169 from theabhirath/vit-hotfix
Hotfix for ViT on GPU
2 parents ee344a9 + 159a695 commit 5bd0bf3

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Metalhead"
22
uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
3-
version = "0.7.2-DEV"
3+
version = "0.7.2"
44

55
[deps]
66
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

src/layers/attention.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
5050
scale = convert(T, sqrt(size(query, 1) / m.nheads))
5151
key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads,
5252
seq_len * batch_size)
53-
query_reshaped = reshape(query, nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
53+
query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
5454
attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
55-
value_reshaped = reshape(value, nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
55+
value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size)
5656
pre_projection = reshape(batched_mul(attention, value_reshaped),
5757
(nfeatures, seq_len, batch_size))
5858
y = m.projection(reshape(pre_projection, size(pre_projection, 1), :))

0 commit comments

Comments
 (0)