@@ -179,18 +179,18 @@ def __init__(

         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
-            nn.LayerNorm(patch_dim),
-            nn.Linear(patch_dim, dim),
-            nn.LayerNorm(dim),
+            NormLinear(patch_dim, dim, norm_dim_in = False),
         )

-        self.abs_pos_emb = nn.Embedding(num_patches, dim)
+        self.abs_pos_emb = NormLinear(dim, num_patches)

         residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)

         # layers

         self.dim = dim
+        self.scale = dim ** 0.5
+
         self.layers = ModuleList([])
         self.residual_lerp_scales = nn.ParameterList([])

@@ -201,8 +201,8 @@ def __init__(
             ]))

             self.residual_lerp_scales.append(nn.ParameterList([
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
+                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
+                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
             ]))

         self.logit_scale = nn.Parameter(torch.ones(num_classes))
@@ -225,22 +225,23 @@ def forward(self, images):

         tokens = self.to_patch_embedding(images)

-        pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))
+        seq_len = tokens.shape[-2]
+        pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]

         tokens = l2norm(tokens + pos_emb)

         for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):

             attn_out = l2norm(attn(tokens))
-            tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
+            tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))

             ff_out = l2norm(ff(tokens))
-            tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
+            tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))

         pooled = reduce(tokens, 'b n d -> b d', 'mean')

         logits = self.to_pred(pooled)
-        logits = logits * self.logit_scale * (self.dim ** 0.5)
+        logits = logits * self.logit_scale * self.scale

         return logits

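Note: the hunks above reference NormLinear, which is defined elsewhere in this file and not shown in the diff. As a rough sketch of what it is assumed to do, based on how it is used here (a bias-free linear layer whose weight is kept l2-normalized along the input or output dimension, selected by norm_dim_in), it might look like the following; the attribute names (self.linear, self.norm_dim) are hypothetical and the actual definition in the repo may differ:

import torch
from torch import nn
import torch.nn.functional as F

class NormLinear(nn.Module):
    # sketch only: a bias-free linear layer whose weight matrix is
    # l2-normalized along one dimension whenever it is accessed
    def __init__(self, dim, dim_out, norm_dim_in = True):
        super().__init__()
        self.linear = nn.Linear(dim, dim_out, bias = False)
        # norm_dim_in = True  -> each row (one per output unit) is unit-norm over the input dim
        # norm_dim_in = False -> each column (one per input unit) is unit-norm over the output dim
        self.norm_dim = -1 if norm_dim_in else 0

    @property
    def weight(self):
        return F.normalize(self.linear.weight, dim = self.norm_dim)

    def forward(self, x):
        # same as nn.Linear without bias, but with the normalized weight
        return F.linear(x, self.weight)

Under that reading, self.abs_pos_emb = NormLinear(dim, num_patches) holds a (num_patches, dim) weight, so abs_pos_emb.weight[torch.arange(seq_len, device = device)] picks out one unit-norm positional vector per patch, replacing the previous nn.Embedding lookup. The residual_lerp_scales change is the matching reparameterization: each interpolation parameter is initialized at residual_lerp_scale_init / √dim and multiplied back by self.scale = √dim in forward, leaving the effective interpolation weight unchanged at init while the stored parameter lives at a smaller magnitude, presumably to control the effective step size on those parameters as in nGPT-style networks.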