Skip to content

Commit e300cdd

Browse files
committed
fix multiheaded qk rmsnorm in nViT
1 parent 36ddc7a commit e300cdd

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setup(
77
name = 'vit-pytorch',
88
packages = find_packages(exclude=['examples']),
9-
version = '1.8.4',
9+
version = '1.8.5',
1010
license='MIT',
1111
description = 'Vision Transformer (ViT) - Pytorch',
1212
long_description=long_description,

vit_pytorch/normalized_vit.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def __init__(
7676

7777
self.dropout = dropout
7878

79-
self.q_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
80-
self.k_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
79+
self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
80+
self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
8181

8282
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
8383
self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -90,15 +90,15 @@ def forward(
9090
):
9191
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
9292

93-
q = q * self.q_scale
94-
k = k * self.k_scale
95-
9693
q, k, v = map(self.split_heads, (q, k, v))
9794

9895
# query key rmsnorm
9996

10097
q, k = map(l2norm, (q, k))
10198

99+
q = q * self.q_scale
100+
k = k * self.k_scale
101+
102102
# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
103103

104104
out = F.scaled_dot_product_attention(

0 commit comments

Comments
 (0)