Commit 36ddc7a

go all the way with the normalized vit, fix some scales
1 parent 1d1a63f commit 36ddc7a

File tree

  setup.py
  vit_pytorch/normalized_vit.py

2 files changed: +12 -11 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.2',
+  version = '1.8.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/normalized_vit.py

Lines changed: 11 additions & 10 deletions
@@ -179,18 +179,18 @@ def __init__(
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
-            nn.LayerNorm(patch_dim),
-            nn.Linear(patch_dim, dim),
-            nn.LayerNorm(dim),
+            NormLinear(patch_dim, dim, norm_dim_in = False),
         )
 
-        self.abs_pos_emb = nn.Embedding(num_patches, dim)
+        self.abs_pos_emb = NormLinear(dim, num_patches)
 
         residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
 
         # layers
 
         self.dim = dim
+        self.scale = dim ** 0.5
+
         self.layers = ModuleList([])
         self.residual_lerp_scales = nn.ParameterList([])
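For readers following the diff without the rest of the file: NormLinear is defined elsewhere in normalized_vit.py and is not shown in this commit. A minimal sketch of what such a layer could look like, assuming it is a bias-free linear whose weight is kept L2-normalized along the dimension selected by norm_dim_in and exposed via a .weight attribute (all assumptions, not the repository's actual definition):

from torch import nn
import torch.nn.functional as F

# Hypothetical sketch only; the real NormLinear in normalized_vit.py may differ.
class NormLinearSketch(nn.Module):
    def __init__(self, dim, dim_out, norm_dim_in = True):
        super().__init__()
        self.linear = nn.Linear(dim, dim_out, bias = False)
        # assumed meaning of norm_dim_in: normalize across the input dim (-1)
        # or across the output dim (0) of the (dim_out, dim) weight matrix
        self.norm_dim = -1 if norm_dim_in else 0

    @property
    def weight(self):
        # unit-norm weight, so e.g. abs_pos_emb.weight rows lie on the hypersphere
        return F.normalize(self.linear.weight, dim = self.norm_dim)

    def forward(self, x):
        return F.linear(x, self.weight)

Under that reading, NormLinear(patch_dim, dim, norm_dim_in = False) replaces the LayerNorm / Linear / LayerNorm patch projection with a single normalized projection, and NormLinear(dim, num_patches) gives a positional embedding table whose rows are unit vectors.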

@@ -201,8 +201,8 @@ def __init__(
             ]))
 
             self.residual_lerp_scales.append(nn.ParameterList([
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
+                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
+                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
             ]))
 
         self.logit_scale = nn.Parameter(torch.ones(num_classes))
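The scale fix in the two hunks above is easiest to verify numerically: the per-dimension lerp scales are now stored divided by sqrt(dim) at init and multiplied back by sqrt(dim) in forward (next hunk), so the effective interpolation weight at initialization is unchanged while the stored parameter magnitude comes down to the order of the other unit-norm weights. A quick check, with dim and depth chosen arbitrarily for illustration:

import torch

dim, depth = 256, 12                      # illustrative values, not from the diff
scale = dim ** 0.5
residual_lerp_scale_init = 1. / depth     # the default used in __init__

stored = torch.ones(dim) * residual_lerp_scale_init / scale   # what __init__ now keeps
effective = stored * scale                                     # what forward() now applies

assert torch.allclose(effective, torch.full((dim,), residual_lerp_scale_init))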
@@ -225,22 +225,23 @@ def forward(self, images):
 
         tokens = self.to_patch_embedding(images)
 
-        pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))
+        seq_len = tokens.shape[-2]
+        pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]
 
         tokens = l2norm(tokens + pos_emb)
 
         for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
 
             attn_out = l2norm(attn(tokens))
-            tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
+            tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))
 
             ff_out = l2norm(ff(tokens))
-            tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
+            tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))
 
         pooled = reduce(tokens, 'b n d -> b d', 'mean')
 
         logits = self.to_pred(pooled)
-        logits = logits * self.logit_scale * (self.dim ** 0.5)
+        logits = logits * self.logit_scale * self.scale
 
         return logits
 
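For context, one residual step of the forward pass after this change, written as standalone code; l2norm, the tensor shapes, and the random stand-in for the attention output are assumptions for illustration only:

import torch
import torch.nn.functional as F

def l2norm(t):
    # unit-normalize along the feature dimension
    return F.normalize(t, dim = -1)

batch, seq, dim, depth = 2, 64, 256, 12       # illustrative shapes
scale = dim ** 0.5

tokens = l2norm(torch.randn(batch, seq, dim))
attn_out = l2norm(torch.randn(batch, seq, dim))       # stand-in for l2norm(attn(tokens))
attn_alpha = torch.ones(dim) * (1. / depth) / scale   # stored as in __init__ above

# tokens.lerp(out, w) == tokens + w * (out - tokens); the result is put back on the sphere
tokens = l2norm(tokens.lerp(attn_out, attn_alpha * scale))

With the logit scaling also switched from self.dim ** 0.5 to the cached self.scale, the forward pass reuses the one precomputed constant throughout.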
