@@ -10,7 +10,7 @@

 from collections import OrderedDict
 from functools import partial
-from typing import List, Final, Optional, Tuple
+from typing import List, Final, Optional, Tuple, Union

 import torch
 import torch.nn as nn
|
@@ -379,17 +379,17 @@ def __init__(
         if self.cls_token is not None:
             trunc_normal_(self.cls_token, std=.02)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         x = self.conv_embed(x)
         x = self.embed_drop(x)

         cls_token = self.embed_drop(
             self.cls_token.expand(x.shape[0], -1, -1)
         ) if self.cls_token is not None else None

-        for block in self.blocks:  # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor
+        for block in self.blocks:  # TODO technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tuple
             x, cls_token = block(x, cls_token)

-        return x, cls_token
+        return (x, cls_token) if self.cls_token is not None else x

 class CvT(nn.Module):
     def __init__(
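
As background on the TODO in this hunk: nn.Sequential simply feeds each module's output into the next module's input with no type constraint on the intermediates, so blocks that accept and return an (x, cls_token) tuple can be chained without an explicit Python loop. A minimal sketch of that pattern (TupleBlock is a hypothetical name, not part of this commit):

import torch
import torch.nn as nn

class TupleBlock(nn.Module):
    def forward(self, inputs):
        x, cls_token = inputs      # unpack the untyped intermediate
        return x + 1, cls_token    # re-pack for the next block

blocks = nn.Sequential(TupleBlock(), TupleBlock())
x, cls_token = blocks((torch.zeros(2, 4), None))  # x incremented twice

One reason to keep the explicit loop anyway: torch.jit.script type-checks the inputs flowing through a Sequential, and tuple intermediates can make that harder to satisfy.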
|
@@ -429,8 +429,8 @@ def __init__(
         assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token)
         self.num_classes = num_classes
         self.num_features = dims[-1]
+        self.feature_info = []

-        # FIXME only on last stage, no need for tuple
         self.use_cls_token = use_cls_token[-1]

         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
@@ -473,28 +473,27 @@ def __init__(
             )
             in_chs = dim
             stages.append(stage)
-        self.stages = nn.ModuleList(stages)
+            self.feature_info += [dict(num_chs=dim, reduction=2, module=f'stages.{stage_idx}')]
+        self.stages = nn.Sequential(*stages)

         self.norm = norm_layer(dims[-1])
         self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:

         for stage in self.stages:
-            x, cls_token = stage(x)
+            x = stage(x)


         if self.use_cls_token:
-            return self.head(self.norm(cls_token.flatten(1)))
+            return self.head(self.norm(x[1].flatten(1)))
         else:
             return self.head(self.norm(x.mean(dim=(2,3))))



 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
-    if 'head.fc.weight' in state_dict:
-        return state_dict  # non-MSFT checkpoint

     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
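
As background (not shown in this commit): populating self.feature_info is what lets timm expose per-stage feature maps via features_only=True, assuming the model is registered and built through build_model_with_cfg with feature support. One caveat worth double-checking: timm conventionally records reduction as the cumulative downsample factor relative to the input, so a constant reduction=2 for every stage may not be what the feature wrappers expect. A hedged usage sketch ('cvt_13' is a hypothetical variant name):

import timm
import torch

model = timm.create_model('cvt_13', features_only=True, pretrained=False)
for f in model(torch.randn(1, 3, 224, 224)):
    print(f.shape)  # one map per stage, channels matching num_chs above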
|
@@ -524,14 +523,13 @@ def _create_cvt(variant, pretrained=False, **kwargs):

     return model

-# TODO update first_conv
 def _cfg(url='', **kwargs):
     return {
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (14, 14),
         'crop_pct': 0.95, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head',
+        'first_conv': 'stages.0.conv_embed.conv', 'classifier': 'head',
         **kwargs
     }

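
As background: timm uses the first_conv entry of the pretrained config to locate the convolution whose weights must be adapted when a model is created with in_chans != 3, so pointing it at the actual patch-embedding conv (rather than the old 'stem.conv' placeholder) is what makes non-RGB weight loading work. A hedged sketch, again with a hypothetical variant name and assuming hosted weights:

import timm

# for in_chans=1, timm folds the RGB kernels of 'stages.0.conv_embed.conv' into one channel
model = timm.create_model('cvt_13', pretrained=True, in_chans=1)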
|
|