Skip to content

Commit efa1a36

Browse files
authored
Cvt 1 (#14)
* Update cvt.py * Update cvt.py
1 parent 7ba93ae commit efa1a36

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

timm/models/cvt.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
self.dim = dim
7171

7272
# FIXME not working, bn layer outputs are incorrect
73-
'''
73+
7474
self.conv_q = ConvNormAct(
7575
dim,
7676
dim,
@@ -143,7 +143,8 @@ def __init__(
143143
groups=dim
144144
)),
145145
('bn', nn.BatchNorm2d(dim)),]))
146-
146+
'''
147+
147148
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
148149
B, C, H, W = x.shape
149150
# [B, C, H, W] -> [B, H*W, C]
@@ -170,7 +171,7 @@ def __init__(
170171
self.num_heads = num_heads
171172
self.head_dim = dim // num_heads
172173
self.scale = dim ** -0.5
173-
self.fused_attn = False #use_fused_attn()
174+
self.fused_attn = use_fused_attn()
174175

175176
self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
176177
self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
@@ -534,11 +535,36 @@ def _cfg(url='', **kwargs):
534535
}
535536

536537
default_cfgs = generate_default_cfgs({
537-
'cvt_13.msft_in1k': _cfg(url='https://files.catbox.moe/xz97kh.pth'),
538+
'cvt_13.msft_in1k': _cfg(
539+
url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-224x224-IN-1k.pth'),
540+
'cvt_13.msft_in1k_384': _cfg(
541+
url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-384x384-IN-1k.pth',
542+
input_size=(3, 384, 384), pool_size=(24, 24)),
543+
'cvt_13.msft_in22k_ft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-384x384-IN-22k.pth',
544+
input_size=(3, 384, 384), pool_size=(24, 24)),
545+
546+
'cvt_21.msft_in1k': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-224x224-IN-1k.pth'),
547+
'cvt_21.msft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-384x384-IN-1k.pth',
548+
input_size=(3, 384, 384), pool_size=(24, 24)),
549+
'cvt_21.msft_in22k_ft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-384x384-IN-22k.pth',
550+
input_size=(3, 384, 384), pool_size=(24, 24)),
551+
552+
'cvt_w24.msft_in22k_ft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-w24-384x384-IN-22k.pth',
553+
input_size=(3, 384, 384), pool_size=(24, 24)),
538554
})
539555

540556

541557
@register_model
542558
def cvt_13(pretrained=False, **kwargs) -> CvT:
543559
model_args = dict(depths=(1, 2, 10), dims=(64, 192, 384), num_heads=(1, 3, 6))
544560
return _create_cvt('cvt_13', pretrained=pretrained, **dict(model_args, **kwargs))
561+
562+
@register_model
563+
def cvt_21(pretrained=False, **kwargs) -> CvT:
564+
model_args = dict(depths=(1, 4, 16), dims=(64, 192, 384), num_heads=(1, 3, 6))
565+
return _create_cvt('cvt_21', pretrained=pretrained, **dict(model_args, **kwargs))
566+
567+
@register_model
568+
def cvt_w24(pretrained=False, **kwargs) -> CvT:
569+
model_args = dict(depths=(2, 2, 20), dims=(192, 768, 1024), num_heads=(3, 12, 16))
570+
return _create_cvt('cvt_w24', pretrained=pretrained, **dict(model_args, **kwargs))

0 commit comments

Comments
 (0)