
Commit f902bcd

Layer refactoring continues, ResNet downsample rewrite for proper dilation in 3x3 and avg_pool cases
* select_conv2d -> create_conv2d
* added create_attn to create attention module from string/bool/module
* factor padding helpers into own file, use in both conv2d_same and avg_pool2d_same
* add some more test eca resnet variants
* minor tweaks, naming, comments, consistency
1 parent a99ec4e commit f902bcd

20 files changed: +311 -163 lines changed
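As a quick orientation, here is a minimal usage sketch of the two factory helpers this commit renames/introduces (import paths as exported in timm/models/layers/__init__.py below; argument values are illustrative, not taken from the diff):

import torch
from timm.models.layers import create_conv2d, create_attn

x = torch.randn(2, 32, 56, 56)

# create_conv2d is the renamed select_conv2d; same call signature.
conv = create_conv2d(32, 64, kernel_size=3, stride=2, padding='same')

# create_attn builds an attention module from a string ('se', 'eca'), a bool, or a module class.
attn = create_attn('se', channels=64)

y = conv(x)
if attn is not None:
    y = attn(y)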

timm/models/efficientnet.py

Lines changed: 6 additions & 6 deletions
@@ -28,7 +28,7 @@
 from .registry import register_model
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.models.layers import select_conv2d
+from timm.models.layers import create_conv2d
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD


@@ -220,7 +220,7 @@ class EfficientNet(nn.Module):

     def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
                  channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
+                 output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
         super(EfficientNet, self).__init__()
         norm_kwargs = norm_kwargs or {}
@@ -232,21 +232,21 @@ def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3,

         # Stem
         stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
         self._in_chs = stem_size

         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
-            channel_multiplier, channel_divisor, channel_min, 32, pad_type, act_layer, se_kwargs,
+            channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
         self.feature_info = builder.features
         self._in_chs = builder.in_chs

         # Head + Pooling
-        self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
+        self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
         self.bn2 = norm_layer(self.num_features, **norm_kwargs)
         self.act2 = act_layer(inplace=True)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -314,7 +314,7 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pr

         # Stem
         stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
         self._in_chs = stem_size
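The new output_stride argument is threaded through to EfficientNetBuilder so that later stages can use dilation instead of further striding. A hedged sketch of the intended effect, assuming the efficientnet entrypoints forward extra keyword arguments to this constructor:

import torch
import timm

# Illustrative only: request 16x total downsampling instead of the default 32x.
# Stages beyond that stride budget are expected to switch to dilated convs.
model = timm.create_model('efficientnet_b0', output_stride=16, pretrained=False)
feats = model.forward_features(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expect 14x14 spatial (224 / 16) rather than the usual 7x7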

timm/models/efficientnet_blocks.py

Lines changed: 9 additions & 9 deletions
@@ -2,7 +2,7 @@
 import torch.nn as nn
 from torch.nn import functional as F
 from .layers.activations import sigmoid
-from .layers import select_conv2d
+from .layers import create_conv2d


 # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
@@ -129,7 +129,7 @@ def __init__(self, in_chs, out_chs, kernel_size,
                  norm_layer=nn.BatchNorm2d, norm_kwargs=None):
         super(ConvBnAct, self).__init__()
         norm_kwargs = norm_kwargs or {}
-        self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
+        self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
         self.bn1 = norm_layer(out_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)

@@ -162,7 +162,7 @@ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.drop_connect_rate = drop_connect_rate

-        self.conv_dw = select_conv2d(
+        self.conv_dw = create_conv2d(
             in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
         self.bn1 = norm_layer(in_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
@@ -174,7 +174,7 @@ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
         else:
             self.se = None

-        self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
+        self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
         self.bn2 = norm_layer(out_chs, **norm_kwargs)
         self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

@@ -223,12 +223,12 @@ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
         self.drop_connect_rate = drop_connect_rate

         # Point-wise expansion
-        self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
+        self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
         self.bn1 = norm_layer(mid_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)

         # Depth-wise convolution
-        self.conv_dw = select_conv2d(
+        self.conv_dw = create_conv2d(
             mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
             padding=pad_type, depthwise=True, **conv_kwargs)
         self.bn2 = norm_layer(mid_chs, **norm_kwargs)
@@ -242,7 +242,7 @@ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
             self.se = None

         # Point-wise linear projection
-        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
+        self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
         self.bn3 = norm_layer(out_chs, **norm_kwargs)

     def feature_module(self, location):
@@ -356,7 +356,7 @@ def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_ch
         self.drop_connect_rate = drop_connect_rate

         # Expansion convolution
-        self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
+        self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
         self.bn1 = norm_layer(mid_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)

@@ -368,7 +368,7 @@ def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_ch
             self.se = None

         # Point-wise linear projection
-        self.conv_pwl = select_conv2d(
+        self.conv_pwl = create_conv2d(
             mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
         self.bn2 = norm_layer(out_chs, **norm_kwargs)
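All blocks above now obtain their convolutions from create_conv2d, with depthwise convs requested via depthwise=True. A small sketch, assuming the factory translates depthwise=True into a grouped conv with groups equal to the channel count:

import torch
from timm.models.layers import create_conv2d

x = torch.randn(1, 32, 28, 28)

# Depthwise 3x3 followed by a pointwise 1x1 projection, mirroring conv_dw / conv_pw above.
conv_dw = create_conv2d(32, 32, 3, stride=1, dilation=1, padding='', depthwise=True)
conv_pw = create_conv2d(32, 64, 1, padding='')

y = conv_pw(conv_dw(x))
print(conv_dw.groups, y.shape)  # 32, torch.Size([1, 64, 28, 28])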

timm/models/gluon_resnet.py

Lines changed: 11 additions & 9 deletions
@@ -11,6 +11,7 @@

 from .registry import register_model
 from .helpers import load_pretrained
+from .layers import SEModule
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 from .resnet import ResNet, Bottleneck, BasicBlock
@@ -319,8 +320,8 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw
     """
     default_cfg = default_cfgs['gluon_seresnext50_32x4d']
     model = ResNet(
-        Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, use_se=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -333,8 +334,8 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k
     """
     default_cfg = default_cfgs['gluon_seresnext101_32x4d']
     model = ResNet(
-        Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, use_se=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -346,9 +347,10 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k
     """Constructs a SEResNeXt-101-64x4d model.
     """
     default_cfg = default_cfgs['gluon_seresnext101_64x4d']
+    block_args = dict(attn_layer=SEModule)
     model = ResNet(
-        Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, use_se=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
+        num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -360,10 +362,10 @@ def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs an SENet-154 model.
     """
     default_cfg = default_cfgs['gluon_senet154']
+    block_args = dict(attn_layer=SEModule)
     model = ResNet(
-        Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, use_se=True,
-        stem_type='deep', down_kernel_size=3, block_reduce_first=2,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3,
+        block_reduce_first=2, num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
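use_se is gone from these entrypoints; attention is now injected through block_args, which the refactored ResNet passes down to its blocks. A minimal sketch of building a SE-ResNeXt-50 the new way (constructor arguments taken from the diff above, pretrained-weight loading omitted):

from timm.models.layers import SEModule
from timm.models.resnet import ResNet, Bottleneck

model = ResNet(
    Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
    num_classes=1000, in_chans=3, block_args=dict(attn_layer=SEModule))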

timm/models/layers/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -1,8 +1,13 @@
+from .padding import get_padding
+from .avg_pool2d_same import AvgPool2dSame
+from .conv2d_same import Conv2dSame
 from .conv_bn_act import ConvBnAct
 from .mixed_conv2d import MixedConv2d
 from .cond_conv2d import CondConv2d, get_condconv_initializer
-from .select_conv2d import select_conv2d
+from .create_conv2d import create_conv2d
+from .create_attn import create_attn
 from .selective_kernel import SelectiveKernelConv
+from .se import SEModule
 from .eca import EcaModule, CecaModule
 from .activations import *
 from .adaptive_avgmax_pool import \

timm/models/layers/avg_pool2d_same.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+""" AvgPool2d w/ Same Padding
+
+Hacked together by Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+import math
+
+from .helpers import tup_pair
+from .padding import pad_same
+
+
+def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
+                    ceil_mode: bool = False, count_include_pad: bool = True):
+    x = pad_same(x, kernel_size, stride)
+    return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+
+class AvgPool2dSame(nn.AvgPool2d):
+    """ Tensorflow like 'SAME' wrapper for 2D average pooling
+    """
+    def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
+        kernel_size = tup_pair(kernel_size)
+        stride = tup_pair(stride)
+        super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+    def forward(self, x):
+        return avg_pool2d_same(
+            x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
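A quick sanity check of the 'SAME' behaviour (sketch; the output spatial size follows TF's ceil(input / stride) rule rather than PyTorch's default floor):

import torch
from timm.models.layers import AvgPool2dSame

pool = AvgPool2dSame(kernel_size=3, stride=2)
x = torch.randn(1, 8, 7, 7)
print(pool(x).shape)  # torch.Size([1, 8, 4, 4]): ceil(7 / 2) = 4, vs. 3 from an unpadded nn.AvgPool2d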

timm/models/layers/cond_conv2d.py

Lines changed: 1 addition & 1 deletion
@@ -10,8 +10,8 @@
 from torch import nn as nn
 from torch.nn import functional as F

+from .helpers import tup_pair
 from .conv2d_same import get_padding_value, conv2d_same
-from .conv_helpers import tup_pair


 def get_condconv_initializer(initializer, num_experts, expert_shape):

timm/models/layers/conv2d_same.py

Lines changed: 3 additions & 16 deletions
@@ -8,26 +8,13 @@
 from typing import Union, List, Tuple, Optional, Callable
 import math

-from .conv_helpers import get_padding
-
-
-def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
-    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
-
-
-def _calc_same_pad(i: int, k: int, s: int, d: int):
-    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
+from .padding import get_padding, pad_same, is_static_pad


 def conv2d_same(
         x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
         padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
-    ih, iw = x.size()[-2:]
-    kh, kw = weight.size()[-2:]
-    pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
-    pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
-    if pad_h > 0 or pad_w > 0:
-        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+    x = pad_same(x, weight.shape[-2:], stride, dilation)
     return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)


@@ -51,7 +38,7 @@ def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
         padding = padding.lower()
         if padding == 'same':
             # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
-            if _is_static_pad(kernel_size, **kwargs):
+            if is_static_pad(kernel_size, **kwargs):
                 # static case, no extra overhead
                 padding = get_padding(kernel_size, **kwargs)
             else:
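get_padding, pad_same and is_static_pad now live in the shared padding module and are reused by conv2d_same and avg_pool2d_same. A small sketch of the static vs. dynamic distinction get_padding_value draws (the dynamic flag signals that Conv2dSame / conv2d_same must pad at runtime); the exact return values are assumed from the logic shown above:

from timm.models.layers.conv2d_same import get_padding_value

# stride 1: SAME padding is input-independent, so a fixed symmetric pad suffices.
pad, dynamic = get_padding_value('same', 3, stride=1, dilation=1)
print(pad, dynamic)   # expected: 1 False

# stride 2: the pad amount depends on the input size, so dynamic padding is flagged.
pad, dynamic = get_padding_value('same', 3, stride=2, dilation=1)
print(dynamic)        # expected: True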

timm/models/layers/conv_bn_act.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 """
 from torch import nn as nn

-from timm.models.layers.conv_helpers import get_padding
+from timm.models.layers import get_padding


 class ConvBnAct(nn.Module):

timm/models/layers/create_attn.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+""" Select Attention Factory Method
+
+Hacked together by Ross Wightman
+"""
+import torch
+from .se import SEModule
+from .eca import EcaModule, CecaModule
+
+
+def create_attn(attn_type, channels, **kwargs):
+    module_cls = None
+    if attn_type is not None:
+        if isinstance(attn_type, str):
+            attn_type = attn_type.lower()
+            if attn_type == 'se':
+                module_cls = SEModule
+            elif attn_type == 'eca':
+                module_cls = EcaModule
+            elif attn_type == 'ceca':
+                module_cls = CecaModule
+            else:
+                assert False, "Invalid attn module (%s)" % attn_type
+        elif isinstance(attn_type, bool):
+            if attn_type:
+                module_cls = SEModule
+        else:
+            module_cls = attn_type
+    if module_cls is not None:
+        return module_cls(channels, **kwargs)
+    return None
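A minimal sketch of the three spec types the factory accepts, assuming the exports added to timm/models/layers above:

from timm.models.layers import create_attn, EcaModule

se_attn  = create_attn('se', channels=256)        # string -> SEModule(256)
eca_attn = create_attn('eca', channels=256)       # string -> EcaModule(256)
default  = create_attn(True, channels=256)        # bool True -> SEModule by default
no_attn  = create_attn(False, channels=256)       # bool False (or None) -> returns None
custom   = create_attn(EcaModule, channels=256)   # a module class is instantiated directly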

timm/models/layers/select_conv2d.py renamed to timm/models/layers/create_conv2d.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-""" Select Conv2d Factory Method
+""" Create Conv2d Factory Method

 Hacked together by Ross Wightman
 """
@@ -8,7 +8,7 @@
 from .conv2d_same import create_conv2d_pad


-def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
+def create_conv2d(in_chs, out_chs, kernel_size, **kwargs):
     """ Select a 2d convolution implementation based on arguments
     Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
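A sketch of the dispatch the docstring describes; the selection rules (a list kernel_size selects MixedConv2d, a positive num_experts selects CondConv2d, otherwise a plain or 'SAME'-padded Conv2d) are assumed here, not shown in this hunk:

import torch.nn as nn
from timm.models.layers import create_conv2d, MixedConv2d

plain = create_conv2d(32, 64, 3, stride=1, padding='')        # nn.Conv2d with computed symmetric padding
same  = create_conv2d(32, 64, 3, stride=2, padding='same')    # Conv2dSame, dynamic TF-style padding
mixed = create_conv2d(32, 64, [3, 5, 7], padding='')          # MixedConv2d over per-group kernel sizes
cond  = create_conv2d(32, 64, 3, padding='', num_experts=4)   # CondConv2d with 4 experts

assert isinstance(plain, nn.Conv2d) and isinstance(mixed, MixedConv2d)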