@@ -1150,25 +1150,25 @@ def _convert_aimv2(
         state_dict: Dict[str, torch.Tensor],
         model: VisionTransformer,
 ) -> Dict[str, torch.Tensor]:
-    #import re
     out_dict = {}
-
     for k, v in state_dict.items():
         k = k.replace('norm_1', 'norm1')
         k = k.replace('norm_2', 'norm2')
         k = k.replace('preprocessor.patchifier.', 'patch_embed.')
         k = k.replace('preprocessor.pos_embed', 'pos_embed')
         k = k.replace('trunk.', '')
-        k = k.replace('mlp.fc1', 'mlp.fc1_g')
-        k = k.replace('mlp.fc3', 'mlp.fc1_x')
         k = k.replace('post_trunk_norm.', 'norm.')
-        # if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
-        #     out_dict[k.replace("w12", "fc1")] = v
-        #     continue
-        # elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
-        #     out_dict[k.replace("w3", "fc2")] = v
-        #     continue
+
+        if 'mlp.fc1' in k:
+            if k in out_dict:
+                v = torch.cat([v, out_dict[k]], dim=0)
+        elif 'mlp.fc3' in k:
+            k = k.replace('mlp.fc3', 'mlp.fc1')
+            if k in out_dict:
+                v = torch.cat([out_dict[k], v], dim=0)
+
         out_dict[k] = v
+
     return out_dict
 
 def checkpoint_filter_fn(
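The rewritten loop above folds AIMv2's two SwiGLU projections into one packed tensor: `mlp.fc1` (gate) and `mlp.fc3` (value) are concatenated along the output dimension, gate half first, under the single key `mlp.fc1`, whichever of the two keys the loop happens to see first. A minimal sketch, not part of this commit, of why that ordering lines up with `SwiGLUPacked` (assuming it chunks the packed fc1 output in two and applies SiLU to the first, gate, half; the toy sizes are placeholders):

import torch
import torch.nn.functional as F

embed_dim, hidden = 8, 22                      # toy sizes; per-branch width ~ embed_dim * 2.75
w_gate = torch.randn(hidden, embed_dim)        # checkpoint 'mlp.fc1' (gate) weight
w_value = torch.randn(hidden, embed_dim)       # checkpoint 'mlp.fc3' (value) weight
w_packed = torch.cat([w_gate, w_value], dim=0) # what _convert_aimv2 stores as 'mlp.fc1'

x = torch.randn(2, embed_dim)
x1, x2 = (x @ w_packed.t()).chunk(2, dim=-1)   # packed forward: split in two, gate chunk first
assert torch.allclose(
    F.silu(x1) * x2,                           # SwiGLUPacked-style: silu(gate) * value
    F.silu(x @ w_gate.t()) * (x @ w_value.t()),  # original two-projection SwiGLU
)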
@@ -3448,8 +3448,8 @@ def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     rms_norm = partial(RmsNorm, eps=1e-5)
     model_args = dict(
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
-        mlp_ratio=2.75, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLU,
-        qkv_bias=False, proj_bias=False,
+        mlp_ratio=5.5, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLUPacked,
+        qkv_bias=False, proj_bias=False, act_layer='silu'
     )
     model = _create_vision_transformer(
         'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
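Because `SwiGLUPacked`'s single fc1 now carries both projections, `mlp_ratio` doubles from 2.75 to 5.5 while each SwiGLU branch keeps the original 2.75x width after chunking. A hedged usage sketch, not from this commit; the model name is taken from this diff's registration, and `blocks[0].mlp.fc1` assumes timm's usual `VisionTransformer`/`GluMlp` attribute layout:

import timm

# pretrained=False so no converted checkpoint is needed for the shape check
model = timm.create_model('vit_large_patch14_aimv2_224', pretrained=False)
fc1 = model.blocks[0].mlp.fc1
assert fc1.out_features == int(1024 * 5.5)  # 5632 packed; 2816 per branch after chunk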