    OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType
+    SwiGLU, get_act_layer, get_norm_layer, LayerType
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -65,6 +65,7 @@ def __init__(
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
+            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
@@ -80,7 +81,7 @@ def __init__(
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
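The new `proj_bias` flag simply toggles the bias term on the attention output projection (and, in the hunks below, on the MLP projections). A minimal standalone sketch of the effect, not the timm module itself:

import torch
import torch.nn as nn

# Sketch only: illustrates what proj_bias controls, assuming the same nn.Linear usage
# as in the Attention module above. With proj_bias=False the output projection carries
# no bias parameter, matching checkpoints (e.g. AIMv2) trained without one.
dim = 1024
proj_with_bias = nn.Linear(dim, dim, bias=True)
proj_without_bias = nn.Linear(dim, dim, bias=False)

assert proj_with_bias.bias is not None
assert proj_without_bias.bias is None

x = torch.randn(2, 196, dim)
print(proj_without_bias(x).shape)  # torch.Size([2, 196, 1024])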
@@ -130,6 +131,7 @@ def __init__(
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
+            proj_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
@@ -145,6 +147,7 @@ def __init__(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
+            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
@@ -157,6 +160,7 @@ def __init__(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
+            bias=proj_bias,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
@@ -176,6 +180,7 @@ def __init__(
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
+            proj_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
@@ -192,6 +197,7 @@ def __init__(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
+            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
@@ -203,6 +209,7 @@ def __init__(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
+            bias=proj_bias,
            drop=proj_drop,
        )
        self.norm2 = norm_layer(dim)
@@ -236,6 +243,7 @@ def __init__(
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
+            proj_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
@@ -266,11 +274,11 @@ def __init__(
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
-        self.attn_out_proj = nn.Linear(dim, dim)
+        self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias)

        self.mlp_drop = nn.Dropout(proj_drop)
        self.mlp_act = act_layer()
-        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
+        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias)

        self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -330,6 +338,7 @@ def __init__(
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
+            proj_bias: bool = True,
            init_values: Optional[float] = None,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
@@ -350,6 +359,7 @@ def __init__(
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
+                    proj_bias=proj_bias,
                    attn_drop=attn_drop,
                    proj_drop=proj_drop,
                    norm_layer=norm_layer,
@@ -363,6 +373,7 @@ def __init__(
                    dim,
                    hidden_features=int(dim * mlp_ratio),
                    act_layer=act_layer,
+                    bias=proj_bias,
                    drop=proj_drop,
                )),
                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
@@ -433,6 +444,7 @@ def __init__(
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
+            proj_bias: bool = True,
            init_values: Optional[float] = None,
            class_token: bool = True,
            pos_embed: str = 'learn',
@@ -452,6 +464,7 @@ def __init__(
            weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
            fix_init: bool = False,
            embed_layer: Callable = PatchEmbed,
+            embed_norm_layer: Optional[LayerType] = None,
            norm_layer: Optional[LayerType] = None,
            act_layer: Optional[LayerType] = None,
            block_fn: Type[nn.Module] = Block,
@@ -483,6 +496,7 @@ def __init__(
            weight_init: Weight initialization scheme.
            fix_init: Apply weight initialization fix (scaling w/ layer index).
            embed_layer: Patch embedding layer.
+            embed_norm_layer: Normalization layer to use / override in patch embed module.
            norm_layer: Normalization layer.
            act_layer: MLP activation layer.
            block_fn: Transformer block layer.
@@ -493,6 +507,7 @@ def __init__(
        assert pos_embed in ('', 'none', 'learn')
        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        embed_norm_layer = get_norm_layer(embed_norm_layer)
        act_layer = get_act_layer(act_layer) or nn.GELU

        self.num_classes = num_classes
@@ -510,6 +525,8 @@ def __init__(
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
+        if embed_norm_layer is not None:
+            embed_args['norm_layer'] = embed_norm_layer
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
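When `embed_norm_layer` is set it is forwarded to the patch embedding module as its `norm_layer`, so a normalization such as RMSNorm can be applied directly after the patch projection. A hedged sketch of how the new argument might be used via `create_model` (the model name and eps value here are illustrative, not from this commit):

from functools import partial

import timm
from timm.layers import RmsNorm

# Sketch, assuming a timm build that includes this change: any ViT can now apply
# RMSNorm inside its patch embed by passing embed_norm_layer through create_model.
model = timm.create_model(
    'vit_base_patch16_224',          # illustrative model choice
    pretrained=False,
    embed_norm_layer=partial(RmsNorm, eps=1e-5),
)
print(type(model.patch_embed.norm).__name__)  # 'RmsNorm' instead of the default 'Identity'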
@@ -539,14 +556,15 @@ def __init__(
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device='cpu')]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
+                proj_bias=proj_bias,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
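The decay rule itself is unchanged: per-block drop-path rates still ramp linearly from 0 to `drop_path_rate`; the added `device='cpu'` only pins the small helper tensor to the CPU. A quick worked example of the values it produces, assuming depth=4 and drop_path_rate=0.1 purely for illustration:

import torch

# Stochastic depth decay rule from the line above, with assumed example values.
drop_path_rate, depth = 0.1, 4
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device='cpu')]
print([round(r, 4) for r in dpr])  # [0.0, 0.0333, 0.0667, 0.1]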
@@ -1128,6 +1146,31 @@ def _convert_dinov2(
    return out_dict


+def _convert_aimv2(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+) -> Dict[str, torch.Tensor]:
+    #import re
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        k = k.replace('norm_1', 'norm1')
+        k = k.replace('norm_2', 'norm2')
+        k = k.replace('preprocessor.patchifier.', 'patch_embed.')
+        k = k.replace('preprocessor.pos_embed', 'pos_embed')
+        k = k.replace('trunk.', '')
+        k = k.replace('mlp.fc1', 'mlp.fc1_g')
+        k = k.replace('mlp.fc3', 'mlp.fc1_x')
+        k = k.replace('post_trunk_norm.', 'norm.')
+        # if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
+        #     out_dict[k.replace("w12", "fc1")] = v
+        #     continue
+        # elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
+        #     out_dict[k.replace("w3", "fc2")] = v
+        #     continue
+        out_dict[k] = v
+    return out_dict
+
def checkpoint_filter_fn(
        state_dict: Dict[str, torch.Tensor],
        model: VisionTransformer,
@@ -1159,6 +1202,8 @@ def checkpoint_filter_fn(
            # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
            out_dict['head.weight'] = state_dict['visual.head.proj.weight']
            out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+    elif 'preprocessor.patchifier.proj.weight' in state_dict:
+        state_dict = _convert_aimv2(state_dict, model)

    if prefix:
        # filter on & remove prefix string from keys
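For reference, the remapping performed by `_convert_aimv2` above turns AIMv2-style checkpoint keys into timm's ViT naming. The snippet below applies the same replace rules to a few representative key names; the example keys are derived from the rules themselves and assumed to match a real `apple/aimv2-*` checkpoint, not verified here:

# Illustrative only: reproduce the _convert_aimv2 replace chain on assumed AIMv2 keys.
examples = [
    'preprocessor.patchifier.proj.weight',
    'preprocessor.pos_embed',
    'trunk.blocks.0.norm_1.weight',
    'trunk.blocks.0.mlp.fc1.weight',   # SwiGLU gate branch -> fc1_g in timm's SwiGLU
    'trunk.blocks.0.mlp.fc3.weight',   # second linear branch -> fc1_x
    'trunk.post_trunk_norm.weight',
]
for k in examples:
    for old, new in [
        ('norm_1', 'norm1'), ('norm_2', 'norm2'),
        ('preprocessor.patchifier.', 'patch_embed.'), ('preprocessor.pos_embed', 'pos_embed'),
        ('trunk.', ''), ('mlp.fc1', 'mlp.fc1_g'), ('mlp.fc3', 'mlp.fc1_x'),
        ('post_trunk_norm.', 'norm.'),
    ]:
        k = k.replace(old, new)
    print(k)
# patch_embed.proj.weight
# pos_embed
# blocks.0.norm1.weight
# blocks.0.mlp.fc1_g.weight
# blocks.0.mlp.fc1_x.weight
# norm.weight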
@@ -2119,6 +2164,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
    ),

+    'vit_large_patch14_aimv2_224': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 224, 224), crop_pct=1.0,
+        num_classes=0),
+
    'test_vit.r160_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 160, 160), crop_pct=0.95),
@@ -3390,6 +3441,21 @@ def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTran
    return model


+@register_model
+def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/14) for AIM-v2 weights: RMSNorm + SwiGLU MLP, no qkv/proj biases.
+    """
+    rms_norm = partial(RmsNorm, eps=1e-5)
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLU,
+        qkv_bias=False, proj_bias=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
@register_model
def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT Test