__all__ = ['SHViT']


-class Residule(nn.Module):
+class Residual(nn.Module):
    def __init__(self, m: nn.Module):
        super().__init__()
        self.m = m
@@ -38,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

    @torch.no_grad()
    def fuse(self) -> nn.Module:
-        if isinstance(self.m, Conv2d_BN):
+        if isinstance(self.m, Conv2dNorm):
            m = self.m.fuse()
            assert(m.groups == m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
@@ -49,7 +49,7 @@ def fuse(self) -> nn.Module:
            return self

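# A minimal standalone sketch of the fusion Residual.fuse performs for a
# depthwise conv: x + conv(x) equals a single conv whose kernel has 1.0 added
# at the spatial center of each channel's kernel. The helper name
# fuse_residual_dw is ours, not from this file; bias is left untouched.
import torch
import torch.nn as nn
import torch.nn.functional as F

def fuse_residual_dw(conv: nn.Conv2d) -> nn.Conv2d:
    assert conv.groups == conv.in_channels          # depthwise only
    identity = torch.ones(conv.weight.shape[0], conv.weight.shape[1], 1, 1)
    pad = (conv.weight.shape[2] - 1) // 2           # e.g. 3x3 kernel -> pad 1 per side
    conv.weight.data += F.pad(identity, [pad] * 4)  # add identity at kernel center
    return conv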
-class Conv2d_BN(nn.Sequential):
+class Conv2dNorm(nn.Sequential):
    def __init__(
            self,
            in_channels: int,
@@ -89,7 +89,7 @@ def fuse(self) -> nn.Conv2d:
        return m

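# A minimal sketch of the conv + BatchNorm folding that Conv2dNorm.fuse
# performs in spirit: scale the conv weight by gamma / sqrt(var + eps) and
# absorb the BN shift into a new bias. The helper name fold_bn is hypothetical;
# it assumes the conv carries no bias, as in Conv2dNorm.
import torch.nn as nn

def fold_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    std = (bn.running_var + bn.eps).sqrt()
    w = conv.weight * (bn.weight / std)[:, None, None, None]
    b = bn.bias - bn.running_mean * bn.weight / std
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=True)
    fused.weight.data.copy_(w)
    fused.bias.data.copy_(b)
    return fused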
-class BN_Linear(nn.Sequential):
+class NormLinear(nn.Sequential):
    def __init__(
            self,
            in_features: int,
@@ -124,12 +124,12 @@ class PatchMerging(nn.Module):
    def __init__(self, dim: int, out_dim: int, act_layer: LayerType = nn.ReLU):
        super().__init__()
        hid_dim = int(dim * 4)
-        self.conv1 = Conv2d_BN(dim, hid_dim)
+        self.conv1 = Conv2dNorm(dim, hid_dim)
        self.act1 = act_layer()
-        self.conv2 = Conv2d_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
+        self.conv2 = Conv2dNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
        self.act2 = act_layer()
        self.se = SqueezeExcite(hid_dim, 0.25)
-        self.conv3 = Conv2d_BN(hid_dim, out_dim)
+        self.conv3 = Conv2dNorm(hid_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
@@ -144,9 +144,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class FFN(nn.Module):
    def __init__(self, dim: int, embed_dim: int, act_layer: LayerType = nn.ReLU):
        super().__init__()
-        self.pw1 = Conv2d_BN(dim, embed_dim)
+        self.pw1 = Conv2dNorm(dim, embed_dim)
        self.act = act_layer()
-        self.pw2 = Conv2d_BN(embed_dim, dim, bn_weight_init=0)
+        self.pw2 = Conv2dNorm(embed_dim, dim, bn_weight_init=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pw1(x)
@@ -173,8 +173,8 @@ def __init__(

        self.pre_norm = norm_layer(pdim)

-        self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
-        self.proj = nn.Sequential(act_layer(), Conv2d_BN(dim, dim, bn_weight_init=0))
+        self.qkv = Conv2dNorm(pdim, qk_dim * 2 + pdim)
+        self.proj = nn.Sequential(act_layer(), Conv2dNorm(dim, dim, bn_weight_init=0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, H, W = x.shape
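# Shape sketch for the qkv/proj dims above: SHSA runs single-head attention on
# only pdim channels and passes the remaining dim - pdim channels through, then
# mixes everything with proj. Values are illustrative; the real forward also
# applies pre_norm and an attention scale factor.
import torch

B, dim, pdim, qk_dim, H, W = 2, 64, 16, 16, 7, 7
x = torch.randn(B, dim, H, W)
x1, x2 = torch.split(x, [pdim, dim - pdim], dim=1)       # attended / bypassed
qkv = torch.randn(B, qk_dim * 2 + pdim, H, W)            # stand-in for self.qkv(x1)
q, k, v = qkv.split([qk_dim, qk_dim, pdim], dim=1)
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)       # (B, C, H*W)
attn = (q.transpose(-2, -1) @ k).softmax(dim=-1)         # (B, H*W, H*W)
x1 = (v @ attn.transpose(-2, -1)).reshape(B, pdim, H, W)
out = torch.cat([x1, x2], dim=1)                         # (B, dim, H, W)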
@@ -202,12 +202,12 @@ def __init__(
            act_layer: LayerType = nn.ReLU,
    ):
        super().__init__()
-        self.conv = Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0))
+        self.conv = Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0))
        if type == "s":
-            self.mixer = Residule(SHSA(dim, qk_dim, pdim, norm_layer, act_layer))
+            self.mixer = Residual(SHSA(dim, qk_dim, pdim, norm_layer, act_layer))
        else:
            self.mixer = nn.Identity()
-        self.ffn = Residule(FFN(dim, int(dim * 2)))
+        self.ffn = Residual(FFN(dim, int(dim * 2)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
@@ -231,11 +231,11 @@ def __init__(
        super().__init__()
        self.grad_checkpointing = False
        self.downsample = nn.Sequential(
-            Residule(Conv2d_BN(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
-            Residule(FFN(prev_dim, int(prev_dim * 2), act_layer)),
+            Residual(Conv2dNorm(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
+            Residual(FFN(prev_dim, int(prev_dim * 2), act_layer)),
            PatchMerging(prev_dim, dim, act_layer),
-            Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim)),
-            Residule(FFN(dim, int(dim * 2), act_layer)),
+            Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim)),
+            Residual(FFN(dim, int(dim * 2), act_layer)),
        ) if prev_dim != dim else nn.Identity()

        self.blocks = nn.Sequential(*[
@@ -274,13 +274,13 @@ def __init__(
        # Patch embedding
        stem_chs = embed_dim[0]
        self.patch_embed = nn.Sequential(
-            Conv2d_BN(in_chans, stem_chs // 8, 3, 2, 1),
+            Conv2dNorm(in_chans, stem_chs // 8, 3, 2, 1),
            act_layer(),
-            Conv2d_BN(stem_chs // 8, stem_chs // 4, 3, 2, 1),
+            Conv2dNorm(stem_chs // 8, stem_chs // 4, 3, 2, 1),
            act_layer(),
-            Conv2d_BN(stem_chs // 4, stem_chs // 2, 3, 2, 1),
+            Conv2dNorm(stem_chs // 4, stem_chs // 2, 3, 2, 1),
            act_layer(),
-            Conv2d_BN(stem_chs // 2, stem_chs, 3, 2, 1)
+            Conv2dNorm(stem_chs // 2, stem_chs, 3, 2, 1)
        )

        # Build SHViT blocks
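# Quick sanity check on the stem above (standalone; the channel counts are made
# up, chosen so the stem_chs // 8 ladder stays integral): four stride-2, 3x3,
# pad-1 convs give an overall stride of 16, e.g. 224x224 -> 14x14.
import torch
import torch.nn as nn

stem = nn.Sequential(
    nn.Conv2d(3, 16, 3, 2, 1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, 2, 1), nn.ReLU(),
    nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU(),
    nn.Conv2d(64, 128, 3, 2, 1),
)
print(stem(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 128, 14, 14])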
@@ -305,7 +305,7 @@ def __init__(
        self.num_features = self.head_hidden_size = embed_dim[-1]
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.head = BN_Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = NormLinear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def no_weight_decay(self) -> Set:
@@ -336,7 +336,7 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.head = BN_Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = NormLinear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()

    def forward_intermediates(
            self,
@@ -426,36 +426,36 @@ def fuse_children(net):


def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
-    if 'model' in state_dict:
-        state_dict = state_dict['model']
-    out_dict = {}
-
-    replace_rules = [
-        (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
-        (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
-        (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
-    ]
-    downsample_mapping = {}
-    for i in range(1, 3):
-        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
-        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
-        downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
-        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
-        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
-        for j in range(3, 10):
-            downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'
-
-    downsample_patterns = [
-        (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
-
-    for k, v in state_dict.items():
-        for pattern, replacement in replace_rules:
-            k = pattern.sub(replacement, k)
-        for pattern, replacement in downsample_patterns:
-            k = pattern.sub(replacement, k)
-        out_dict[k] = v
-
-    return out_dict
+    state_dict = state_dict.get('model', state_dict)
+
+    # out_dict = {}
+    #
+    # replace_rules = [
+    #     (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
+    #     (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
+    #     (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
+    # ]
+    # downsample_mapping = {}
+    # for i in range(1, 3):
+    #     downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
+    #     downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
+    #     downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
+    #     downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
+    #     downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
+    #     for j in range(3, 10):
+    #         downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'
+    #
+    # downsample_patterns = [
+    #     (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
+    #
+    # for k, v in state_dict.items():
+    #     for pattern, replacement in replace_rules:
+    #         k = pattern.sub(replacement, k)
+    #     for pattern, replacement in downsample_patterns:
+    #         k = pattern.sub(replacement, k)
+    #     out_dict[k] = v
+
+    return state_dict
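# Note on the rewrite above: dict.get('model', state_dict) unwraps checkpoints
# saved as {'model': <state_dict>} and passes already-flat dicts through
# unchanged. The commented-out remapping translated original-release SHViT keys
# (blocks1/blocks2/blocks3) into the stages.N.{downsample,blocks} layout; it is
# retained only for reference, presumably because the newly hosted weights
# already use the new naming.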


def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
@@ -473,20 +473,20 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:

default_cfgs = generate_default_cfgs({
    'shvit_s1.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s1.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s1.pth',
    ),
    'shvit_s2.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s2.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s2.pth',
    ),
    'shvit_s3.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s3.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s3.pth',
    ),
    'shvit_s4.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s4.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s4.pth',
        input_size=(3, 256, 256),
    ),
})
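# Hedged usage sketch: with these cfgs registered (the model entry points are
# defined elsewhere in this file), pretrained weights now resolve through the
# hf_hub_id entries rather than the GitHub release URLs.
import timm

model = timm.create_model('shvit_s1.in1k', pretrained=True).eval()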