Commit 6171e75

Fix MQA V2 scale and out shape
1 parent 851e074

2 files changed (+22, -6 lines)

tests/test_layers.py

Lines changed: 17 additions & 1 deletion
@@ -1,7 +1,8 @@
+import pytest
 import torch
 import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, MultiQueryAttentionV2
 
 import importlib
 import os
@@ -119,3 +120,18 @@ def test_get_act_fn_none():
     assert get_act_fn(None) is None
     assert get_act_fn('') is None
 
+@pytest.mark.parametrize("dim", [128])
+@pytest.mark.parametrize("dim_out", [128, 256])
+@pytest.mark.parametrize("use_m", [True, False])
+def test_mqa_v2(dim, dim_out, use_m):
+    mqa = MultiQueryAttentionV2(dim, dim_out)
+
+    x = torch.randn(1, dim, 32, 48)
+    if use_m:
+        m = torch.randn(1, dim, 16, 24)
+    else:
+        m = None
+
+    y = mqa(x, m=m)
+
+    assert y.shape == (1, dim_out, 32, 48)
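For a quick manual check of the shape contract the new test pins down, a minimal sketch outside pytest (it assumes a timm build that exports MultiQueryAttentionV2 from timm.layers, as the import above does):

    import torch
    from timm.layers import MultiQueryAttentionV2

    # NCHW input; dim_out may differ from dim, and m is an optional memory tensor.
    mqa = MultiQueryAttentionV2(128, 256)
    x = torch.randn(1, 128, 32, 48)
    m = torch.randn(1, 128, 16, 24)   # smaller spatial size than x is fine
    y = mqa(x, m=m)
    print(y.shape)                    # expected: torch.Size([1, 256, 32, 48])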

timm/layers/attention2d.py

Lines changed: 5 additions & 5 deletions
@@ -59,24 +59,24 @@ def _reshape_input(self, t):
 
     def forward(self, x, m: Optional[torch.Tensor] = None):
         """Run layer computation."""
-        s = x.shape
-        m = m or x
+        b, _, h, w = x.shape
+        m = m if m is not None else x
 
         reshaped_x = self._reshape_input(x)
         reshaped_m = self._reshape_input(m)
 
         q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
         k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)
 
-        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
+        attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
         v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
         o = torch.einsum('bnhm,bmv->bnhv', attn, v)
-        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
+        result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
         result = self.proj_drop(result)
-        return result.reshape(s)
+        return result.reshape(b, -1, h, w)
 
 
 class MultiQueryAttention2d(nn.Module):
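Two of the fixes in this hunk can be reproduced in isolation, as a minimal sketch with made-up sizes rather than the layer's real parameters (the third change simply applies the module's self.scale to the attention logits before softmax):

    import torch

    # 1) Why `m = m or x` was a bug: `or` calls bool() on the tensor, which is
    #    ambiguous for multi-element tensors, so passing a real memory tensor raised.
    m = torch.randn(1, 8, 4, 4)
    try:
        _ = m or None
    except RuntimeError as e:
        print(e)  # "Boolean value of Tensor with more than one element is ambiguous"
    # `m if m is not None else x` only tests for None and passes tensors through.

    # 2) Why the output einsum changed from 'bnd' to 'bdn' (illustrative sizes):
    b, h, w, heads, v_dim, d = 2, 3, 4, 4, 16, 64
    n = h * w                                    # number of query tokens
    o = torch.randn(b, n, heads, v_dim)          # per-head attention output
    out_proj = torch.randn(d, heads, v_dim)      # (out_channels, heads, value_dim)
    old = torch.einsum('bnhv,dhv->bnd', o, out_proj)  # (b, n, d): tokens first
    new = torch.einsum('bnhv,dhv->bdn', o, out_proj)  # (b, d, n): channels first
    # Only the channels-first layout reshapes back to NCHW without scrambling,
    # and reshape(b, -1, h, w) also works when out_channels != in_channels,
    # where the old reshape(s) to the input's shape could not.
    y = new.reshape(b, -1, h, w)                 # (b, d, h, w)
    assert y.shape == (2, 64, 3, 4)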
