@@ -70,7 +70,7 @@ def __init__(
70
70
self .dim = dim
71
71
72
72
# FIXME: not working; BatchNorm layer outputs are incorrect
73
- '''
73
+
74
74
self .conv_q = ConvNormAct (
75
75
dim ,
76
76
dim ,
@@ -143,7 +143,8 @@ def __init__(
143
143
groups=dim
144
144
)),
145
145
('bn', nn.BatchNorm2d(dim)),]))
146
-
146
+ '''
147
+
147
148
def forward (self , x : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
148
149
B , C , H , W = x .shape
149
150
# [B, C, H, W] -> [B, H*W, C]
@@ -170,7 +171,7 @@ def __init__(
170
171
self .num_heads = num_heads
171
172
self .head_dim = dim // num_heads
172
173
self .scale = dim ** - 0.5
173
- self .fused_attn = False # use_fused_attn()
174
+ self .fused_attn = use_fused_attn ()
174
175
175
176
self .proj_q = nn .Linear (dim , dim , bias = qkv_bias )
176
177
self .proj_k = nn .Linear (dim , dim , bias = qkv_bias )
@@ -534,11 +535,36 @@ def _cfg(url='', **kwargs):
534
535
}
535
536
536
537
# Pretrained config entries, keyed as '<architecture>.<tag>'. Entries for the
# 384x384 checkpoints override `input_size` and `pool_size`; all other entries
# pass only a checkpoint URL to `_cfg`.
default_cfgs = generate_default_cfgs({
    # CvT-13
    'cvt_13.msft_in1k': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-224x224-IN-1k.pth',
    ),
    'cvt_13.msft_in1k_384': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-384x384-IN-1k.pth',
        input_size=(3, 384, 384),
        pool_size=(24, 24),
    ),
    'cvt_13.msft_in22k_ft_in1k_384': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-384x384-IN-22k.pth',
        input_size=(3, 384, 384),
        pool_size=(24, 24),
    ),

    # CvT-21
    'cvt_21.msft_in1k': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-224x224-IN-1k.pth',
    ),
    'cvt_21.msft_in1k_384': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-384x384-IN-1k.pth',
        input_size=(3, 384, 384),
        pool_size=(24, 24),
    ),
    'cvt_21.msft_in22k_ft_in1k_384': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-384x384-IN-22k.pth',
        input_size=(3, 384, 384),
        pool_size=(24, 24),
    ),

    # CvT-W24
    'cvt_w24.msft_in22k_ft_in1k_384': _cfg(
        url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-w24-384x384-IN-22k.pth',
        input_size=(3, 384, 384),
        pool_size=(24, 24),
    ),
})
539
555
540
556
541
557
@register_model
def cvt_13(pretrained=False, **kwargs) -> CvT:
    """Create the 'cvt_13' model variant.

    Args:
        pretrained: Forwarded unchanged to `_create_cvt`.
        **kwargs: Extra model arguments; any key given here takes precedence
            over the variant defaults below.

    Returns:
        The model produced by `_create_cvt('cvt_13', ...)`.
    """
    model_args = {
        'depths': (1, 2, 10),
        'dims': (64, 192, 384),
        'num_heads': (1, 3, 6),
    }
    # Caller-supplied values override the variant defaults.
    model_args.update(kwargs)
    return _create_cvt('cvt_13', pretrained=pretrained, **model_args)
561
+
562
@register_model
def cvt_21(pretrained=False, **kwargs) -> CvT:
    """Create the 'cvt_21' model variant.

    Args:
        pretrained: Forwarded unchanged to `_create_cvt`.
        **kwargs: Extra model arguments; any key given here takes precedence
            over the variant defaults below.

    Returns:
        The model produced by `_create_cvt('cvt_21', ...)`.
    """
    model_args = {
        'depths': (1, 4, 16),
        'dims': (64, 192, 384),
        'num_heads': (1, 3, 6),
    }
    # Caller-supplied values override the variant defaults.
    model_args.update(kwargs)
    return _create_cvt('cvt_21', pretrained=pretrained, **model_args)
566
+
567
@register_model
def cvt_w24(pretrained=False, **kwargs) -> CvT:
    """Create the 'cvt_w24' model variant.

    Args:
        pretrained: Forwarded unchanged to `_create_cvt`.
        **kwargs: Extra model arguments; any key given here takes precedence
            over the variant defaults below.

    Returns:
        The model produced by `_create_cvt('cvt_w24', ...)`.
    """
    model_args = {
        'depths': (2, 2, 20),
        'dims': (192, 768, 1024),
        'num_heads': (3, 12, 16),
    }
    # Caller-supplied values override the variant defaults.
    model_args.update(kwargs)
    return _create_cvt('cvt_w24', pretrained=pretrained, **model_args)
0 commit comments