Skip to content

Commit e69b906

Browse files
committed
Update cvt.py
1 parent 186dab3 commit e69b906

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

timm/models/cvt.py

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111
from collections import OrderedDict
1212
from functools import partial
13-
from typing import List, Final, Optional, Tuple
13+
from typing import List, Final, Optional, Tuple, Union
1414

1515
import torch
1616
import torch.nn as nn
@@ -379,17 +379,17 @@ def __init__(
379379
if self.cls_token is not None:
380380
trunc_normal_(self.cls_token, std=.02)
381381

382-
def forward(self, x: torch.Tensor) -> torch.Tensor:
382+
def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
383383
x = self.conv_embed(x)
384384
x = self.embed_drop(x)
385385

386386
cls_token = self.embed_drop(
387387
self.cls_token.expand(x.shape[0], -1, -1)
388388
) if self.cls_token is not None else None
389-
for block in self.blocks: # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor
389+
for block in self.blocks: # TODO technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tuple
390390
x, cls_token = block(x, cls_token)
391391

392-
return x, cls_token
392+
return (x, cls_token) if self.cls_token is not None else x
393393

394394
class CvT(nn.Module):
395395
def __init__(
@@ -429,8 +429,8 @@ def __init__(
429429
assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token)
430430
self.num_classes = num_classes
431431
self.num_features = dims[-1]
432+
self.feature_info = []
432433

433-
# FIXME only on last stage, no need for tuple
434434
self.use_cls_token = use_cls_token[-1]
435435

436436
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
@@ -473,28 +473,27 @@ def __init__(
473473
)
474474
in_chs = dim
475475
stages.append(stage)
476-
self.stages = nn.ModuleList(stages)
476+
self.feature_info += [dict(num_chs=dim, reduction=2, module=f'stages.{stage_idx}')]
477+
self.stages = nn.Sequential(*stages)
477478

478479
self.norm = norm_layer(dims[-1])
479480
self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
480481

481482
def forward(self, x: torch.Tensor) -> torch.Tensor:
482483

483484
for stage in self.stages:
484-
x, cls_token = stage(x)
485+
x = stage(x)
485486

486487

487488
if self.use_cls_token:
488-
return self.head(self.norm(cls_token.flatten(1)))
489+
return self.head(self.norm(x[1].flatten(1)))
489490
else:
490491
return self.head(self.norm(x.mean(dim=(2,3))))
491492

492493

493494

494495
def checkpoint_filter_fn(state_dict, model):
495496
""" Remap MSFT checkpoints -> timm """
496-
if 'head.fc.weight' in state_dict:
497-
return state_dict # non-MSFT checkpoint
498497

499498
if 'state_dict' in state_dict:
500499
state_dict = state_dict['state_dict']
@@ -524,14 +523,13 @@ def _create_cvt(variant, pretrained=False, **kwargs):
524523

525524
return model
526525

527-
# TODO update first_conv
528526
def _cfg(url='', **kwargs):
529527
return {
530528
'url': url,
531529
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (14, 14),
532530
'crop_pct': 0.95, 'interpolation': 'bicubic',
533531
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
534-
'first_conv': 'stem.conv', 'classifier': 'head',
532+
'first_conv': 'stages.0.conv_embed.conv', 'classifier': 'head',
535533
**kwargs
536534
}
537535

0 commit comments

Comments
 (0)