
Commit 5804d92

Switch aimv2 to use packed SwiGLU
1 parent 15406a9 commit 5804d92

2 files changed: +15 -15 lines

timm/layers/mlp.py

Lines changed: 3 additions & 3 deletions
@@ -83,9 +83,9 @@ def __init__(
 
     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
-        fc1_mid = self.fc1.bias.shape[0] // 2
-        nn.init.ones_(self.fc1.bias[fc1_mid:])
-        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
+        if self.fc1.bias is not None:
+            nn.init.ones_(self.fc1.bias[self.fc1.bias.shape[0] // 2:])
+        nn.init.normal_(self.fc1.weight[self.fc1.weight.shape[0] // 2:], std=1e-6)
 
     def forward(self, x):
         x = self.fc1(x)
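
For orientation: a packed SwiGLU layer computes the value and gate projections with a single fc1 whose output is chunked in half before the elementwise gating, which is why init_weights above only touches the slice from shape[0] // 2 onward. A minimal sketch of that idea, assuming a SiLU activation; the class and argument names are illustrative, not timm's API, and which half acts as the gate is a convention:

import torch
import torch.nn as nn


class PackedSwiGLUSketch(nn.Module):
    """Minimal sketch of a packed SwiGLU MLP: one fc1 carries both halves."""

    def __init__(self, in_features: int, hidden_features: int, bias: bool = True):
        super().__init__()
        # fc1 packs the value and gate projections into a single matmul
        self.fc1 = nn.Linear(in_features, hidden_features * 2, bias=bias)
        self.act = nn.SiLU()
        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # split the packed projection back into its two halves
        x1, x2 = self.fc1(x).chunk(2, dim=-1)
        # one half is SiLU-activated and gates the other elementwise
        return self.fc2(self.act(x1) * x2)


y = PackedSwiGLUSketch(8, 16)(torch.randn(2, 8))  # -> shape (2, 8)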

timm/models/vision_transformer.py

Lines changed: 12 additions & 12 deletions
@@ -1150,25 +1150,25 @@ def _convert_aimv2(
         state_dict: Dict[str, torch.Tensor],
         model: VisionTransformer,
 ) -> Dict[str, torch.Tensor]:
-    #import re
     out_dict = {}
-
     for k, v in state_dict.items():
         k = k.replace('norm_1', 'norm1')
         k = k.replace('norm_2', 'norm2')
         k = k.replace('preprocessor.patchifier.', 'patch_embed.')
         k = k.replace('preprocessor.pos_embed', 'pos_embed')
         k = k.replace('trunk.', '')
-        k = k.replace('mlp.fc1', 'mlp.fc1_g')
-        k = k.replace('mlp.fc3', 'mlp.fc1_x')
         k = k.replace('post_trunk_norm.', 'norm.')
-        # if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
-        #     out_dict[k.replace("w12", "fc1")] = v
-        #     continue
-        # elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
-        #     out_dict[k.replace("w3", "fc2")] = v
-        #     continue
+
+        if 'mlp.fc1' in k:
+            if k in out_dict:
+                v = torch.cat([v, out_dict[k]], dim=0)
+        elif 'mlp.fc3' in k:
+            k = k.replace('mlp.fc3', 'mlp.fc1')
+            if k in out_dict:
+                v = torch.cat([out_dict[k], v], dim=0)
+
         out_dict[k] = v
+
     return out_dict
 
 def checkpoint_filter_fn(
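
The converter above assumes the source AIM-v2 checkpoint stores its SwiGLU MLP as two separate projections, mlp.fc1 and mlp.fc3, and stacks them along the output dimension into the single packed fc1, keeping the fc1 rows first regardless of which key is seen first during iteration. A toy run of that packing logic (hypothetical keys and shapes, chosen only for illustration):

import torch

hidden, dim = 4, 3
src = {
    'blocks.0.mlp.fc1.weight': torch.randn(hidden, dim),  # one SwiGLU branch in the source checkpoint
    'blocks.0.mlp.fc3.weight': torch.randn(hidden, dim),  # the other branch
}

out = {}
for k, v in src.items():
    if 'mlp.fc1' in k:
        if k in out:  # fc3 already stored under the fc1 key -> put fc1 rows first
            v = torch.cat([v, out[k]], dim=0)
    elif 'mlp.fc3' in k:
        k = k.replace('mlp.fc3', 'mlp.fc1')
        if k in out:  # fc1 already stored -> append fc3 rows after it
            v = torch.cat([out[k], v], dim=0)
    out[k] = v

print(out['blocks.0.mlp.fc1.weight'].shape)  # torch.Size([8, 3]): both branches packed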
@@ -3448,8 +3448,8 @@ def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     rms_norm = partial(RmsNorm, eps=1e-5)
     model_args = dict(
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
-        mlp_ratio=2.75, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLU,
-        qkv_bias=False, proj_bias=False,
+        mlp_ratio=5.5, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLUPacked,
+        qkv_bias=False, proj_bias=False, act_layer='silu'
     )
     model = _create_vision_transformer(
         'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
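
One plausible reading of the mlp_ratio change from 2.75 to 5.5: the packed fc1 must cover both SwiGLU branches in a single matrix, so the ratio doubles to keep the effective hidden width per branch unchanged. A quick arithmetic check of that assumption, using the embed_dim from the config above:

embed_dim = 1024

# Separate SwiGLU: the gate and value projections are each mlp_ratio * embed_dim wide
per_branch = int(2.75 * embed_dim)   # 2816 features per branch

# Packed SwiGLU: a single fc1 holds both branches stacked along its output dimension
packed = int(5.5 * embed_dim)        # 5632 packed output features

assert packed == 2 * per_branch      # effective hidden width per branch is unchanged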
