
Commit 5247eb3

Merge pull request #233 from rwightman/torchamp
Native Torch AMP and channels_last support for train.py and validate.py
2 parents 6d158ad + 751b0bb commit 5247eb3

16 files changed: 316 additions, 226 deletions
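
The train.py/validate.py changes themselves are not shown in this excerpt. For context, native AMP plus channels_last training in PyTorch 1.6+ follows the pattern sketched below; this is a minimal illustration using stock torch.cuda.amp APIs, not a copy of the actual train.py code, and model/loader/optimizer/loss_fn are placeholder names.

import torch

# Minimal sketch of a native-AMP + channels_last training step (PyTorch >= 1.6).
# `model`, `loader`, `optimizer`, `loss_fn` are placeholders, not code from this PR.
def train_one_epoch(model, loader, optimizer, loss_fn, device='cuda'):
    model = model.to(device).to(memory_format=torch.channels_last)  # NHWC memory layout
    scaler = torch.cuda.amp.GradScaler()
    for x, y in loader:
        x = x.to(device, memory_format=torch.channels_last)
        y = y.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():      # run forward pass and loss in mixed precision
            loss = loss_fn(model(x), y)
        scaler.scale(loss).backward()        # scale loss to avoid fp16 gradient underflow
        scaler.step(optimizer)               # unscales gradients, then steps the optimizer
        scaler.update()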

tests/test_models.py

Lines changed: 6 additions & 0 deletions
@@ -120,6 +120,12 @@ def test_model_load_pretrained(model_name, batch_size):
     in_chans = 3 if 'pruned' in model_name else 1  # pruning not currently supported with in_chans change
     create_model(model_name, pretrained=True, in_chans=in_chans)
 
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(pretrained=True))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_features_pretrained(model_name, batch_size):
+    """Create that pretrained weights load when features_only==True."""
+    create_model(model_name, pretrained=True, features_only=True)
 
 
 EXCLUDE_JIT_FILTERS = [
     '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
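
The new test only asserts that pretrained weights load when features_only==True. A quick interactive check of the same path might look like the sketch below (the model name is an arbitrary example, not taken from this commit):

import torch
from timm import create_model

# Build a feature-extraction variant and run a dummy batch through it.
m = create_model('resnet18', pretrained=True, features_only=True)
feats = m(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # one feature map per extraction stage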

timm/models/efficientnet_blocks.py

Lines changed: 3 additions & 5 deletions
@@ -106,20 +106,18 @@ class SqueezeExcite(nn.Module):
     def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
                  act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
         super(SqueezeExcite, self).__init__()
-        self.gate_fn = gate_fn
         reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
         self.act1 = act_layer(inplace=True)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
+        self.gate_fn = gate_fn
 
     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.conv_reduce(x_se)
         x_se = self.act1(x_se)
         x_se = self.conv_expand(x_se)
-        x = x * self.gate_fn(x_se)
-        return x
+        return x * self.gate_fn(x_se)
 
 
 class ConvBnAct(nn.Module):
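
The recurring change in this commit (here and in cbam.py, eca.py, se.py and selective_kernel.py below) replaces a stored nn.AdaptiveAvgPool2d(1) module with a plain mean over the spatial dims. For a global pool down to 1x1 the two are numerically equivalent; a quick standalone check, independent of timm:

import torch
import torch.nn as nn

x = torch.randn(2, 8, 7, 7)
pooled_module = nn.AdaptiveAvgPool2d(1)(x)    # (2, 8, 1, 1) via the module
pooled_mean = x.mean((2, 3), keepdim=True)    # same shape and values via a reduction
assert torch.allclose(pooled_module, pooled_mean, atol=1e-6)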

timm/models/factory.py

Lines changed: 4 additions & 7 deletions
@@ -39,20 +39,17 @@ def create_model(
         kwargs.pop('bn_momentum', None)
         kwargs.pop('bn_eps', None)
 
-    # Parameters that aren't supported by all models should default to None in command line args,
-    # remove them if they are present and not set so that non-supporting models don't break.
-    if kwargs.get('drop_block_rate', None) is None:
-        kwargs.pop('drop_block_rate', None)
-
     # handle backwards compat with drop_connect -> drop_path change
     drop_connect_rate = kwargs.pop('drop_connect_rate', None)
     if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
         print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
               " Setting drop_path to %f." % drop_connect_rate)
         kwargs['drop_path_rate'] = drop_connect_rate
 
-    if kwargs.get('drop_path_rate', None) is None:
-        kwargs.pop('drop_path_rate', None)
+    # Parameters that aren't supported by all models or are intended to only override model defaults if set
+    # should default to None in command line args/cfg. Remove them if they are present and not set so that
+    # non-supporting models don't break and default args remain in effect.
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
 
     with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
         if is_model(model_name):
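
The new dict comprehension drops every kwarg whose value is None, so CLI options that default to None never reach the model constructor and the model's own defaults stay in effect. A small illustration (the argument names are examples only):

kwargs = {'drop_rate': 0.2, 'drop_path_rate': None, 'drop_block_rate': None}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
print(kwargs)  # {'drop_rate': 0.2}, unset options are simply omitted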

timm/models/helpers.py

Lines changed: 20 additions & 9 deletions
@@ -48,30 +48,41 @@ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
     model.load_state_dict(state_dict, strict=strict)
 
 
-def resume_checkpoint(model, checkpoint_path):
-    other_state = {}
+def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+            if log_info:
+                _logger.info('Restoring model state from checkpoint...')
             new_state_dict = OrderedDict()
             for k, v in checkpoint['state_dict'].items():
                 name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
-            if 'optimizer' in checkpoint:
-                other_state['optimizer'] = checkpoint['optimizer']
-            if 'amp' in checkpoint:
-                other_state['amp'] = checkpoint['amp']
+
+            if optimizer is not None and 'optimizer' in checkpoint:
+                if log_info:
+                    _logger.info('Restoring optimizer state from checkpoint...')
+                optimizer.load_state_dict(checkpoint['optimizer'])
+
+            if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
+                if log_info:
+                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
+                loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
+
             if 'epoch' in checkpoint:
                 resume_epoch = checkpoint['epoch']
                 if 'version' in checkpoint and checkpoint['version'] > 1:
                     resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
-            _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
+
+            if log_info:
+                _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
         else:
             model.load_state_dict(checkpoint)
-            _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
-        return other_state, resume_epoch
+            if log_info:
+                _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
+        return resume_epoch
     else:
         _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
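
With the new signature, callers hand the optimizer and an optional AMP loss-scaler object to resume_checkpoint and get back only the resume epoch; optimizer and scaler state are restored in place. A hedged usage sketch, assuming a scaler object exposing state_dict_key/load_state_dict (such as the native-AMP scaler wrapper this PR adds) and pre-existing model/optimizer objects; the checkpoint path is an example:

from timm.models import resume_checkpoint

# `model`, `optimizer`, `loss_scaler` are assumed to exist; loss_scaler may be
# None when AMP is not in use, in which case only model/optimizer state is restored.
resume_epoch = resume_checkpoint(
    model, './output/train/checkpoint.pth.tar',
    optimizer=optimizer, loss_scaler=loss_scaler, log_info=True)
start_epoch = resume_epoch if resume_epoch is not None else 0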

timm/models/hrnet.py

Lines changed: 3 additions & 1 deletion
@@ -773,12 +773,14 @@ def forward(self, x) -> List[torch.tensor]:
 
 def _create_hrnet(variant, pretrained, **model_kwargs):
     model_cls = HighResolutionNet
+    strict = True
     if model_kwargs.pop('features_only', False):
         model_cls = HighResolutionNetFeatures
+        strict = False
 
     return build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
-        model_cfg=cfg_cls[variant], **model_kwargs)
+        model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
 
 
 @register_model

timm/models/layers/adaptive_avgmax_pool.py

Lines changed: 14 additions & 1 deletion
@@ -49,6 +49,15 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
     return x
 
 
+class FastAdaptiveAvgPool2d(nn.Module):
+    def __init__(self, flatten=False):
+        super(FastAdaptiveAvgPool2d, self).__init__()
+        self.flatten = flatten
+
+    def forward(self, x):
+        return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True)
+
+
 class AdaptiveAvgMaxPool2d(nn.Module):
     def __init__(self, output_size=1):
         super(AdaptiveAvgMaxPool2d, self).__init__()
@@ -70,12 +79,16 @@ def forward(self, x):
 class SelectAdaptivePool2d(nn.Module):
     """Selectable global pooling layer with dynamic input kernel size
     """
-    def __init__(self, output_size=1, pool_type='avg', flatten=False):
+    def __init__(self, output_size=1, pool_type='fast', flatten=False):
         super(SelectAdaptivePool2d, self).__init__()
         self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
         self.flatten = flatten
         if pool_type == '':
             self.pool = nn.Identity()  # pass through
+        elif pool_type == 'fast':
+            assert output_size == 1
+            self.pool = FastAdaptiveAvgPool2d(self.flatten)
+            self.flatten = False
         elif pool_type == 'avg':
             self.pool = nn.AdaptiveAvgPool2d(output_size)
         elif pool_type == 'avgmax':
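
The new 'fast' pool type (now the constructor default) selects the mean-based pooling by name; it only supports output_size == 1 and folds any requested flatten into the pooling op itself. A short usage sketch, assuming SelectAdaptivePool2d is exported from timm.models.layers as in the rest of the library:

import torch
from timm.models.layers import SelectAdaptivePool2d

pool = SelectAdaptivePool2d(pool_type='fast', flatten=True)
x = torch.randn(2, 512, 7, 7)
print(pool(x).shape)  # torch.Size([2, 512]), pooled and flattened in one reduction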

timm/models/layers/cbam.py

Lines changed: 4 additions & 5 deletions
@@ -10,6 +10,7 @@
 
 import torch
 from torch import nn as nn
+import torch.nn.functional as F
 from .conv_bn_act import ConvBnAct
 
 
@@ -18,15 +19,13 @@ class ChannelAttn(nn.Module):
     """
     def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
         super(ChannelAttn, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.max_pool = nn.AdaptiveMaxPool2d(1)
         self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
         self.act = act_layer(inplace=True)
         self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
 
     def forward(self, x):
-        x_avg = self.avg_pool(x)
-        x_max = self.max_pool(x)
+        x_avg = x.mean((2, 3), keepdim=True)
+        x_max = F.adaptive_max_pool2d(x, 1)
         x_avg = self.fc2(self.act(self.fc1(x_avg)))
         x_max = self.fc2(self.act(self.fc1(x_max)))
         x_attn = x_avg + x_max
@@ -40,7 +39,7 @@ def __init__(self, channels, reduction=16):
         super(LightChannelAttn, self).__init__(channels, reduction)
 
     def forward(self, x):
-        x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x)
+        x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1)
         x_attn = self.fc2(self.act(self.fc1(x_pool)))
         return x * x_attn.sigmoid()
 

timm/models/layers/eca.py

Lines changed: 6 additions & 23 deletions
@@ -52,22 +52,15 @@ class EcaModule(nn.Module):
     def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(EcaModule, self).__init__()
         assert kernel_size % 2 == 1
-
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)
 
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
 
     def forward(self, x):
-        # Feature descriptor on the global spatial information
-        y = self.avg_pool(x)
-        # Reshape for convolution
-        y = y.view(x.shape[0], 1, -1)
-        # Two different branches of ECA module
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)  # view for 1d conv
         y = self.conv(y)
-        # Multi-scale information fusion
         y = y.view(x.shape[0], -1, 1, 1).sigmoid()
         return x * y.expand_as(x)
 
@@ -95,30 +88,20 @@ class CecaModule(nn.Module):
     def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(CecaModule, self).__init__()
         assert kernel_size % 2 == 1
-
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)
 
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        #pytorch circular padding mode is buggy as of pytorch 1.4
-        #see https://github.com/pytorch/pytorch/pull/17240
-
-        #implement manual circular padding
+        # PyTorch circular padding mode is buggy as of pytorch 1.4
+        # see https://github.com/pytorch/pytorch/pull/17240
+        # implement manual circular padding
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
         self.padding = (kernel_size - 1) // 2
 
     def forward(self, x):
-        # Feature descriptor on the global spatial information
-        y = self.avg_pool(x)
-
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)
         # Manually implement circular padding, F.pad does not seemed to be bugged
-        y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
-
-        # Two different branches of ECA module
+        y = F.pad(y, (self.padding, self.padding), mode='circular')
         y = self.conv(y)
-
-        # Multi-scale information fusion
         y = y.view(x.shape[0], -1, 1, 1).sigmoid()
-
         return x * y.expand_as(x)
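
CecaModule keeps its manual circular padding: the pooled channel descriptor is reshaped to (B, 1, C), wrapped with F.pad(mode='circular'), then passed through the 1d conv. The standalone snippet below mirrors that reshape-and-pad step with arbitrary shapes and a padding of 1:

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 7, 7)                  # (B, C, H, W)
y = x.mean((2, 3)).view(x.shape[0], 1, -1)    # (B, 1, C) channel descriptor for the 1d conv
y = F.pad(y, (1, 1), mode='circular')         # wrap-around padding across the channel dim
print(y.shape)                                # torch.Size([2, 1, 18])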

timm/models/layers/se.py

Lines changed: 11 additions & 15 deletions
@@ -1,40 +1,36 @@
 from torch import nn as nn
-from .create_act import get_act_fn
+from .create_act import create_act_layer
 
 
 class SEModule(nn.Module):
 
     def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
-                 gate_fn='sigmoid'):
+                 gate_layer='sigmoid'):
         super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         reduction_channels = reduction_channels or max(channels // reduction, min_channels)
-        self.fc1 = nn.Conv2d(
-            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.act = act_layer(inplace=True)
-        self.fc2 = nn.Conv2d(
-            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
+        self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc1(x_se)
         x_se = self.act(x_se)
         x_se = self.fc2(x_se)
-        return x * self.gate_fn(x_se)
+        return x * self.gate(x_se)
 
 
 class EffectiveSEModule(nn.Module):
     """ 'Effective Squeeze-Excitation
     From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
     """
-    def __init__(self, channels, gate_fn='hard_sigmoid'):
+    def __init__(self, channels, gate_layer='hard_sigmoid'):
         super(EffectiveSEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.gate = create_act_layer(gate_layer, inplace=True)
 
     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc(x_se)
-        return x * self.gate_fn(x_se, inplace=True)
+        return x * self.gate(x_se)
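
Switching from get_act_fn to create_act_layer makes the gate a named submodule rather than a bare function, so it appears in the module tree, and the gate_fn argument becomes gate_layer accordingly. A small usage sketch; the import path assumes SEModule is exported from timm.models.layers, which may vary by timm version:

import torch
from timm.models.layers import SEModule

se = SEModule(64, gate_layer='hard_sigmoid')  # gate resolved by name via create_act_layer
x = torch.randn(2, 64, 14, 14)
print(se(x).shape)  # torch.Size([2, 64, 14, 14]), input re-weighted channel-wise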

timm/models/layers/selective_kernel.py

Lines changed: 1 addition & 3 deletions
@@ -27,16 +27,14 @@ def __init__(self, channels, num_paths=2, attn_channels=32,
         """
         super(SelectiveKernelAttn, self).__init__()
         self.num_paths = num_paths
-        self.pool = nn.AdaptiveAvgPool2d(1)
         self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
         self.bn = norm_layer(attn_channels)
         self.act = act_layer(inplace=True)
         self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
 
     def forward(self, x):
         assert x.shape[1] == self.num_paths
-        x = torch.sum(x, dim=1)
-        x = self.pool(x)
+        x = x.sum(1).mean((2, 3), keepdim=True)
         x = self.fc_reduce(x)
         x = self.bn(x)
         x = self.act(x)
