
Commit a648a04

Supporting aimv2 encoders
1 parent 3a6661a commit a648a04

File tree

2 files changed: +73 −6 lines

timm/layers/mlp.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -132,7 +132,8 @@ def __init__(
 
     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
-        nn.init.ones_(self.fc1_g.bias)
+        if self.fc1_g.bias is not None:
+            nn.init.ones_(self.fc1_g.bias)
         nn.init.normal_(self.fc1_g.weight, std=1e-6)
 
     def forward(self, x):
```
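The guard matters for gated MLP variants constructed without biases, which the AIM-v2 config needs. A minimal sketch (dimensions are illustrative; assumes timm's `SwiGLU` layer at or after this commit):

```python
import torch
from timm.layers import SwiGLU

# Bias-free SwiGLU gate: fc1_g.bias is None, so the previous unconditional
# nn.init.ones_(self.fc1_g.bias) call would have raised.
mlp = SwiGLU(in_features=1024, hidden_features=2816, bias=False)
mlp.init_weights()                           # now skips the bias init when bias is disabled
print(mlp.fc1_g.bias is None)                # True
print(mlp(torch.randn(2, 16, 1024)).shape)   # torch.Size([2, 16, 1024])
```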

timm/models/vision_transformer.py

Lines changed: 71 additions & 5 deletions
```diff
@@ -44,7 +44,7 @@
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType
+    SwiGLU, get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
```
```diff
@@ -65,6 +65,7 @@ def __init__(
             num_heads: int = 8,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             norm_layer: nn.Module = nn.LayerNorm,
@@ -80,7 +81,7 @@ def __init__(
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
```
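A short sketch of the new flag from the caller's side (sizes are illustrative; assumes the post-commit `Attention` constructor):

```python
import torch
from timm.models.vision_transformer import Attention

# AIM-v2 checkpoints carry no bias on the attention output projection,
# so proj_bias=False builds nn.Linear(dim, dim, bias=False) internally.
attn = Attention(dim=1024, num_heads=16, qkv_bias=False, proj_bias=False)
print(attn.proj.bias is None)                  # True
print(attn(torch.randn(2, 256, 1024)).shape)   # torch.Size([2, 256, 1024])
```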
```diff
@@ -130,6 +131,7 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -145,6 +147,7 @@ def __init__(
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
@@ -157,6 +160,7 @@ def __init__(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            bias=proj_bias,
             drop=proj_drop,
         )
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
```
```diff
@@ -176,6 +180,7 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -192,6 +197,7 @@ def __init__(
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
@@ -203,6 +209,7 @@ def __init__(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            bias=proj_bias,
             drop=proj_drop,
         )
         self.norm2 = norm_layer(dim)
```
```diff
@@ -236,6 +243,7 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -266,11 +274,11 @@ def __init__(
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.attn_out_proj = nn.Linear(dim, dim)
+        self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias)
 
         self.mlp_drop = nn.Dropout(proj_drop)
         self.mlp_act = act_layer()
-        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
+        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias)
 
         self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
```
```diff
@@ -330,6 +338,7 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             init_values: Optional[float] = None,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
@@ -350,6 +359,7 @@ def __init__(
                     num_heads=num_heads,
                     qkv_bias=qkv_bias,
                     qk_norm=qk_norm,
+                    proj_bias=proj_bias,
                     attn_drop=attn_drop,
                     proj_drop=proj_drop,
                     norm_layer=norm_layer,
@@ -363,6 +373,7 @@ def __init__(
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
+                    bias=proj_bias,
                     drop=proj_drop,
                 )),
                 ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
```
```diff
@@ -433,6 +444,7 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             init_values: Optional[float] = None,
             class_token: bool = True,
             pos_embed: str = 'learn',
@@ -452,6 +464,7 @@ def __init__(
             weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
             fix_init: bool = False,
             embed_layer: Callable = PatchEmbed,
+            embed_norm_layer: Optional[LayerType] = None,
             norm_layer: Optional[LayerType] = None,
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
@@ -483,6 +496,7 @@ def __init__(
             weight_init: Weight initialization scheme.
             fix_init: Apply weight initialization fix (scaling w/ layer index).
             embed_layer: Patch embedding layer.
+            embed_norm_layer: Normalization layer to use / override in patch embed module.
             norm_layer: Normalization layer.
             act_layer: MLP activation layer.
             block_fn: Transformer block layer.
@@ -493,6 +507,7 @@ def __init__(
         assert pos_embed in ('', 'none', 'learn')
         use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        embed_norm_layer = get_norm_layer(embed_norm_layer)
         act_layer = get_act_layer(act_layer) or nn.GELU
 
         self.num_classes = num_classes
```
```diff
@@ -510,6 +525,8 @@ def __init__(
         if dynamic_img_size:
             # flatten deferred until after pos embed
             embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
+        if embed_norm_layer is not None:
+            embed_args['norm_layer'] = embed_norm_layer
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
@@ -539,14 +556,15 @@ def __init__(
         self.patch_drop = nn.Identity()
         self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
 
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device='cpu')]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
             block_fn(
                 dim=embed_dim,
                 num_heads=num_heads,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 qk_norm=qk_norm,
+                proj_bias=proj_bias,
                 init_values=init_values,
                 proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
```
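To see how the new arguments reach the patch embed and the blocks, here is a small sketch constructing a reduced-depth model directly (the depth and sizes are illustrative, not a released config; assumes the post-commit `VisionTransformer` signature):

```python
import torch
from functools import partial
from timm.layers import RmsNorm, SwiGLU
from timm.models.vision_transformer import VisionTransformer

# embed_norm_layer overrides the norm inside PatchEmbed; proj_bias is threaded
# through every block's attention and MLP projections.
rms_norm = partial(RmsNorm, eps=1e-5)
model = VisionTransformer(
    img_size=224, patch_size=14, embed_dim=1024, depth=2, num_heads=16,
    mlp_ratio=2.75, mlp_layer=SwiGLU, class_token=False, global_pool='avg', fc_norm=False,
    qkv_bias=False, proj_bias=False,
    norm_layer=rms_norm, embed_norm_layer=rms_norm,
    num_classes=0,
)
print(type(model.patch_embed.norm).__name__)      # RmsNorm
print(model.blocks[0].attn.proj.bias is None)     # True
print(model(torch.randn(1, 3, 224, 224)).shape)   # torch.Size([1, 1024])
```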
```diff
@@ -1128,6 +1146,31 @@ def _convert_dinov2(
     return out_dict
 
 
+def _convert_aimv2(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+) -> Dict[str, torch.Tensor]:
+    #import re
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        k = k.replace('norm_1', 'norm1')
+        k = k.replace('norm_2', 'norm2')
+        k = k.replace('preprocessor.patchifier.', 'patch_embed.')
+        k = k.replace('preprocessor.pos_embed', 'pos_embed')
+        k = k.replace('trunk.', '')
+        k = k.replace('mlp.fc1', 'mlp.fc1_g')
+        k = k.replace('mlp.fc3', 'mlp.fc1_x')
+        k = k.replace('post_trunk_norm.', 'norm.')
+        # if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
+        #     out_dict[k.replace("w12", "fc1")] = v
+        #     continue
+        # elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
+        #     out_dict[k.replace("w3", "fc2")] = v
+        #     continue
+        out_dict[k] = v
+    return out_dict
+
 def checkpoint_filter_fn(
         state_dict: Dict[str, torch.Tensor],
         model: VisionTransformer,
```
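The conversion is pure key renaming; a standalone sketch on a few representative AIM-v2 key names (example keys only, not a full checkpoint):

```python
# Mirrors the replacements in _convert_aimv2: patchifier -> patch_embed,
# trunk.* -> top level, norm_1/norm_2 -> norm1/norm2, fc1/fc3 -> SwiGLU's fc1_g/fc1_x.
example_keys = [
    'preprocessor.patchifier.proj.weight',
    'trunk.blocks.0.norm_1.weight',
    'trunk.blocks.0.mlp.fc1.weight',
    'trunk.blocks.0.mlp.fc3.weight',
    'trunk.post_trunk_norm.weight',
]
for k in example_keys:
    k = k.replace('norm_1', 'norm1').replace('norm_2', 'norm2')
    k = k.replace('preprocessor.patchifier.', 'patch_embed.')
    k = k.replace('preprocessor.pos_embed', 'pos_embed')
    k = k.replace('trunk.', '')
    k = k.replace('mlp.fc1', 'mlp.fc1_g').replace('mlp.fc3', 'mlp.fc1_x')
    k = k.replace('post_trunk_norm.', 'norm.')
    print(k)
# patch_embed.proj.weight
# blocks.0.norm1.weight
# blocks.0.mlp.fc1_g.weight
# blocks.0.mlp.fc1_x.weight
# norm.weight
```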
```diff
@@ -1159,6 +1202,8 @@ def checkpoint_filter_fn(
         # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         out_dict['head.weight'] = state_dict['visual.head.proj.weight']
         out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+    elif 'preprocessor.patchifier.proj.weight' in state_dict:
+        state_dict = _convert_aimv2(state_dict, model)
 
     if prefix:
         # filter on & remove prefix string from keys
```
```diff
@@ -2119,6 +2164,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
     ),
 
+    'vit_large_patch14_aimv2_224': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 224, 224), crop_pct=1.0,
+        num_classes=0),
+
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
```
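The new cfg implies CLIP-style preprocessing at 224x224 with crop_pct=1.0. A sketch of resolving it through timm's data helpers (assumes this commit is installed and a timm version that provides `resolve_model_data_config`):

```python
import timm
from timm.data import resolve_model_data_config, create_transform

# pretrained=False still attaches the default cfg, so the preprocessing
# (OPENAI_CLIP_MEAN/STD, 224x224 input, crop_pct=1.0) can be resolved offline.
model = timm.create_model('vit_large_patch14_aimv2_224', pretrained=False)
data_cfg = resolve_model_data_config(model)
print(data_cfg['input_size'], data_cfg['mean'], data_cfg['std'], data_cfg['crop_pct'])
transform = create_transform(**data_cfg, is_training=False)
```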
```diff
@@ -3390,6 +3441,21 @@ def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/14) AIM-v2 image encoder.
+    """
+    rms_norm = partial(RmsNorm, eps=1e-5)
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLU,
+        qkv_bias=False, proj_bias=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT Test
```
