Skip to content

Commit 4209788

Browse files
committed
Updated faster models with hub weight locations, commented out some checkpoint filter functions, and made minor renames
1 parent 75823ab commit 4209788

File tree

5 files changed

+125
-124
lines changed

5 files changed

+125
-124
lines changed

timm/models/fasternet.py

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -369,32 +369,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
369369

370370

371371
def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
372-
if 'avgpool_pre_head' in state_dict:
373-
return state_dict
374-
375-
out_dict = {
376-
'conv_head.weight': state_dict.pop('avgpool_pre_head.1.weight'),
377-
'classifier.weight': state_dict.pop('head.weight'),
378-
'classifier.bias': state_dict.pop('head.bias')
379-
}
380-
381-
stage_mapping = {
382-
'stages.1.': 'stages.1.downsample.',
383-
'stages.2.': 'stages.1.',
384-
'stages.3.': 'stages.2.downsample.',
385-
'stages.4.': 'stages.2.',
386-
'stages.5.': 'stages.3.downsample.',
387-
'stages.6.': 'stages.3.'
388-
}
389-
390-
for k, v in state_dict.items():
391-
for old_prefix, new_prefix in stage_mapping.items():
392-
if k.startswith(old_prefix):
393-
k = k.replace(old_prefix, new_prefix)
394-
break
395-
out_dict[k] = v
396-
397-
return out_dict
372+
# if 'avgpool_pre_head' in state_dict:
373+
# return state_dict
374+
#
375+
# out_dict = {
376+
# 'conv_head.weight': state_dict.pop('avgpool_pre_head.1.weight'),
377+
# 'classifier.weight': state_dict.pop('head.weight'),
378+
# 'classifier.bias': state_dict.pop('head.bias')
379+
# }
380+
#
381+
# stage_mapping = {
382+
# 'stages.1.': 'stages.1.downsample.',
383+
# 'stages.2.': 'stages.1.',
384+
# 'stages.3.': 'stages.2.downsample.',
385+
# 'stages.4.': 'stages.2.',
386+
# 'stages.5.': 'stages.3.downsample.',
387+
# 'stages.6.': 'stages.3.'
388+
# }
389+
#
390+
# for k, v in state_dict.items():
391+
# for old_prefix, new_prefix in stage_mapping.items():
392+
# if k.startswith(old_prefix):
393+
# k = k.replace(old_prefix, new_prefix)
394+
# break
395+
# out_dict[k] = v
396+
return state_dict
398397

399398

400399
def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
@@ -412,28 +411,28 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
412411

413412
default_cfgs = generate_default_cfgs({
414413
'fasternet_t0.in1k': _cfg(
415-
# hf_hub_id='timm/',
416-
url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t0-epoch.281-val_acc1.71.9180.pth',
414+
hf_hub_id='timm/',
415+
#url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t0-epoch.281-val_acc1.71.9180.pth',
417416
),
418417
'fasternet_t1.in1k': _cfg(
419-
# hf_hub_id='timm/',
420-
url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t1-epoch.291-val_acc1.76.2180.pth',
418+
hf_hub_id='timm/',
419+
#url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t1-epoch.291-val_acc1.76.2180.pth',
421420
),
422421
'fasternet_t2.in1k': _cfg(
423-
# hf_hub_id='timm/',
424-
url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t2-epoch.289-val_acc1.78.8860.pth',
422+
hf_hub_id='timm/',
423+
#url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t2-epoch.289-val_acc1.78.8860.pth',
425424
),
426425
'fasternet_s.in1k': _cfg(
427-
# hf_hub_id='timm/',
428-
url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_s-epoch.299-val_acc1.81.2840.pth',
426+
hf_hub_id='timm/',
427+
#url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_s-epoch.299-val_acc1.81.2840.pth',
429428
),
430429
'fasternet_m.in1k': _cfg(
431-
# hf_hub_id='timm/',
432-
url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_m-epoch.291-val_acc1.82.9620.pth',
430+
hf_hub_id='timm/',
431+
#url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_m-epoch.291-val_acc1.82.9620.pth',
433432
),
434433
'fasternet_l.in1k': _cfg(
435-
# hf_hub_id='timm/',
436-
url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_l-epoch.299-val_acc1.83.5060.pth',
434+
hf_hub_id='timm/',
435+
#url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_l-epoch.299-val_acc1.83.5060.pth',
437436
),
438437
})
439438

timm/models/ghostnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -872,8 +872,8 @@ def _cfg(url='', **kwargs):
872872
),
873873
'ghostnetv3_050.untrained': _cfg(),
874874
'ghostnetv3_100.in1k': _cfg(
875-
# hf_hub_id='timm/',
876-
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV3/ghostnetv3-1.0.pth.tar'
875+
hf_hub_id='timm/',
876+
#url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV3/ghostnetv3-1.0.pth.tar'
877877
),
878878
'ghostnetv3_130.untrained': _cfg(),
879879
'ghostnetv3_160.untrained': _cfg(),

timm/models/shvit.py

Lines changed: 62 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
__all__ = ['SHViT']
2929

3030

31-
class Residule(nn.Module):
31+
class Residual(nn.Module):
3232
def __init__(self, m: nn.Module):
3333
super().__init__()
3434
self.m = m
@@ -38,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
3838

3939
@torch.no_grad()
4040
def fuse(self) -> nn.Module:
41-
if isinstance(self.m, Conv2d_BN):
41+
if isinstance(self.m, Conv2dNorm):
4242
m = self.m.fuse()
4343
assert(m.groups == m.in_channels)
4444
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
@@ -49,7 +49,7 @@ def fuse(self) -> nn.Module:
4949
return self
5050

5151

52-
class Conv2d_BN(nn.Sequential):
52+
class Conv2dNorm(nn.Sequential):
5353
def __init__(
5454
self,
5555
in_channels: int,
@@ -89,7 +89,7 @@ def fuse(self) -> nn.Conv2d:
8989
return m
9090

9191

92-
class BN_Linear(nn.Sequential):
92+
class NormLinear(nn.Sequential):
9393
def __init__(
9494
self,
9595
in_features: int,
@@ -124,12 +124,12 @@ class PatchMerging(nn.Module):
124124
def __init__(self, dim: int, out_dim: int, act_layer: LayerType = nn.ReLU):
125125
super().__init__()
126126
hid_dim = int(dim * 4)
127-
self.conv1 = Conv2d_BN(dim, hid_dim)
127+
self.conv1 = Conv2dNorm(dim, hid_dim)
128128
self.act1 = act_layer()
129-
self.conv2 = Conv2d_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
129+
self.conv2 = Conv2dNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
130130
self.act2 = act_layer()
131131
self.se = SqueezeExcite(hid_dim, 0.25)
132-
self.conv3 = Conv2d_BN(hid_dim, out_dim)
132+
self.conv3 = Conv2dNorm(hid_dim, out_dim)
133133

134134
def forward(self, x: torch.Tensor) -> torch.Tensor:
135135
x = self.conv1(x)
@@ -144,9 +144,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
144144
class FFN(nn.Module):
145145
def __init__(self, dim: int, embed_dim: int, act_layer: LayerType = nn.ReLU):
146146
super().__init__()
147-
self.pw1 = Conv2d_BN(dim, embed_dim)
147+
self.pw1 = Conv2dNorm(dim, embed_dim)
148148
self.act = act_layer()
149-
self.pw2 = Conv2d_BN(embed_dim, dim, bn_weight_init=0)
149+
self.pw2 = Conv2dNorm(embed_dim, dim, bn_weight_init=0)
150150

151151
def forward(self, x: torch.Tensor) -> torch.Tensor:
152152
x = self.pw1(x)
@@ -173,8 +173,8 @@ def __init__(
173173

174174
self.pre_norm = norm_layer(pdim)
175175

176-
self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
177-
self.proj = nn.Sequential(act_layer(), Conv2d_BN(dim, dim, bn_weight_init=0))
176+
self.qkv = Conv2dNorm(pdim, qk_dim * 2 + pdim)
177+
self.proj = nn.Sequential(act_layer(), Conv2dNorm(dim, dim, bn_weight_init=0))
178178

179179
def forward(self, x: torch.Tensor) -> torch.Tensor:
180180
B, _, H, W = x.shape
@@ -202,12 +202,12 @@ def __init__(
202202
act_layer: LayerType = nn.ReLU,
203203
):
204204
super().__init__()
205-
self.conv = Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0))
205+
self.conv = Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0))
206206
if type == "s":
207-
self.mixer = Residule(SHSA(dim, qk_dim, pdim, norm_layer, act_layer))
207+
self.mixer = Residual(SHSA(dim, qk_dim, pdim, norm_layer, act_layer))
208208
else:
209209
self.mixer = nn.Identity()
210-
self.ffn = Residule(FFN(dim, int(dim * 2)))
210+
self.ffn = Residual(FFN(dim, int(dim * 2)))
211211

212212
def forward(self, x: torch.Tensor) -> torch.Tensor:
213213
x = self.conv(x)
@@ -231,11 +231,11 @@ def __init__(
231231
super().__init__()
232232
self.grad_checkpointing = False
233233
self.downsample = nn.Sequential(
234-
Residule(Conv2d_BN(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
235-
Residule(FFN(prev_dim, int(prev_dim * 2), act_layer)),
234+
Residual(Conv2dNorm(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
235+
Residual(FFN(prev_dim, int(prev_dim * 2), act_layer)),
236236
PatchMerging(prev_dim, dim, act_layer),
237-
Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim)),
238-
Residule(FFN(dim, int(dim * 2), act_layer)),
237+
Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim)),
238+
Residual(FFN(dim, int(dim * 2), act_layer)),
239239
) if prev_dim != dim else nn.Identity()
240240

241241
self.blocks = nn.Sequential(*[
@@ -274,13 +274,13 @@ def __init__(
274274
# Patch embedding
275275
stem_chs = embed_dim[0]
276276
self.patch_embed = nn.Sequential(
277-
Conv2d_BN(in_chans, stem_chs // 8, 3, 2, 1),
277+
Conv2dNorm(in_chans, stem_chs // 8, 3, 2, 1),
278278
act_layer(),
279-
Conv2d_BN(stem_chs // 8, stem_chs // 4, 3, 2, 1),
279+
Conv2dNorm(stem_chs // 8, stem_chs // 4, 3, 2, 1),
280280
act_layer(),
281-
Conv2d_BN(stem_chs // 4, stem_chs // 2, 3, 2, 1),
281+
Conv2dNorm(stem_chs // 4, stem_chs // 2, 3, 2, 1),
282282
act_layer(),
283-
Conv2d_BN(stem_chs // 2, stem_chs, 3, 2, 1)
283+
Conv2dNorm(stem_chs // 2, stem_chs, 3, 2, 1)
284284
)
285285

286286
# Build SHViT blocks
@@ -305,7 +305,7 @@ def __init__(
305305
self.num_features = self.head_hidden_size = embed_dim[-1]
306306
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
307307
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
308-
self.head = BN_Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
308+
self.head = NormLinear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
309309

310310
@torch.jit.ignore
311311
def no_weight_decay(self) -> Set:
@@ -336,7 +336,7 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
336336
# cannot meaningfully change pooling of efficient head after creation
337337
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
338338
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
339-
self.head = BN_Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
339+
self.head = NormLinear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
340340

341341
def forward_intermediates(
342342
self,
@@ -426,36 +426,36 @@ def fuse_children(net):
426426

427427

428428
def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
429-
if 'model' in state_dict:
430-
state_dict = state_dict['model']
431-
out_dict = {}
432-
433-
replace_rules = [
434-
(re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
435-
(re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
436-
(re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
437-
]
438-
downsample_mapping = {}
439-
for i in range(1, 3):
440-
downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
441-
downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
442-
downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
443-
downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
444-
downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
445-
for j in range(3, 10):
446-
downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'
447-
448-
downsample_patterns = [
449-
(re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
450-
451-
for k, v in state_dict.items():
452-
for pattern, replacement in replace_rules:
453-
k = pattern.sub(replacement, k)
454-
for pattern, replacement in downsample_patterns:
455-
k = pattern.sub(replacement, k)
456-
out_dict[k] = v
457-
458-
return out_dict
429+
state_dict = state_dict.get('model', state_dict)
430+
431+
# out_dict = {}
432+
#
433+
# replace_rules = [
434+
# (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
435+
# (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
436+
# (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
437+
# ]
438+
# downsample_mapping = {}
439+
# for i in range(1, 3):
440+
# downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
441+
# downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
442+
# downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
443+
# downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
444+
# downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
445+
# for j in range(3, 10):
446+
# downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'
447+
#
448+
# downsample_patterns = [
449+
# (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
450+
#
451+
# for k, v in state_dict.items():
452+
# for pattern, replacement in replace_rules:
453+
# k = pattern.sub(replacement, k)
454+
# for pattern, replacement in downsample_patterns:
455+
# k = pattern.sub(replacement, k)
456+
# out_dict[k] = v
457+
458+
return state_dict
459459

460460

461461
def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
@@ -473,20 +473,20 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
473473

474474
default_cfgs = generate_default_cfgs({
475475
'shvit_s1.in1k': _cfg(
476-
# hf_hub_id='timm/',
477-
url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s1.pth',
476+
hf_hub_id='timm/',
477+
#url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s1.pth',
478478
),
479479
'shvit_s2.in1k': _cfg(
480-
# hf_hub_id='timm/',
481-
url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s2.pth',
480+
hf_hub_id='timm/',
481+
#url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s2.pth',
482482
),
483483
'shvit_s3.in1k': _cfg(
484-
# hf_hub_id='timm/',
485-
url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s3.pth',
484+
hf_hub_id='timm/',
485+
#url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s3.pth',
486486
),
487487
'shvit_s4.in1k': _cfg(
488-
# hf_hub_id='timm/',
489-
url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s4.pth',
488+
hf_hub_id='timm/',
489+
#url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s4.pth',
490490
input_size=(3, 256, 256),
491491
),
492492
})

timm/models/starnet.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
253253

254254

255255
def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
256-
if 'state_dict' in state_dict:
257-
state_dict = state_dict['state_dict']
258-
out_dict = state_dict
259-
return out_dict
256+
return state_dict.get('state_dict', state_dict)
260257

261258

262259
def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
@@ -274,20 +271,20 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
274271

275272
default_cfgs = generate_default_cfgs({
276273
'starnet_s1.in1k': _cfg(
277-
# hf_hub_id='timm/',
278-
url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar',
274+
hf_hub_id='timm/',
275+
#url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar',
279276
),
280277
'starnet_s2.in1k': _cfg(
281-
# hf_hub_id='timm/',
282-
url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar',
278+
hf_hub_id='timm/',
279+
#url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar',
283280
),
284281
'starnet_s3.in1k': _cfg(
285-
# hf_hub_id='timm/',
286-
url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar',
282+
hf_hub_id='timm/',
283+
#url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar',
287284
),
288285
'starnet_s4.in1k': _cfg(
289-
# hf_hub_id='timm/',
290-
url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar',
286+
hf_hub_id='timm/',
287+
#url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar',
291288
),
292289
'starnet_s050.untrained': _cfg(),
293290
'starnet_s100.untrained': _cfg(),

0 commit comments

Comments
 (0)