
Commit 3d9be78

A bit more ResNet cleanup.

* add inplace=True back
* minor comment improvements
* few clarity changes

1 parent 33436fa

File tree

1 file changed: +23 -30 lines changed


timm/models/resnet.py

Lines changed: 23 additions & 30 deletions
@@ -79,20 +79,17 @@ def __init__(self, channels, reduction_channels):
         #self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.fc1 = nn.Conv2d(
             channels, reduction_channels, kernel_size=1, padding=0, bias=True)
-        self.relu = nn.ReLU()
+        self.relu = nn.ReLU(inplace=True)
         self.fc2 = nn.Conv2d(
             reduction_channels, channels, kernel_size=1, padding=0, bias=True)
-        self.sigmoid = nn.Sigmoid()
 
     def forward(self, x):
-        module_input = x
-        #x = self.avg_pool(x)
-        x = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
-        x = self.fc1(x)
-        x = self.relu(x)
-        x = self.fc2(x)
-        x = self.sigmoid(x)
-        return module_input * x
+        #x_se = self.avg_pool(x)
+        x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
+        x_se = self.fc1(x_se)
+        x_se = self.relu(x_se)
+        x_se = self.fc2(x_se)
+        return x * x_se.sigmoid()
 
 
 class BasicBlock(nn.Module):
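For context (not part of the diff): the rewritten forward is the standard squeeze-and-excitation gating path — global average pool down to 1x1, a bottlenecked pair of 1x1 convs, then a sigmoid gate broadcast back over the spatial dims. Renaming to x_se removes the need to stash module_input, and the fused x_se.sigmoid() replaces the separate nn.Sigmoid module. A minimal standalone sketch with illustrative sizes, not the timm module itself:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes; SEModule takes (channels, reduction_channels).
channels, reduction_channels = 64, 16
fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True)
fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True)

x = torch.randn(2, channels, 8, 8)
# The view/mean/view chain is a manual AdaptiveAvgPool2d(1): one scalar per channel.
x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
x_se = fc2(F.relu(fc1(x_se)))
out = x * x_se.sigmoid()   # per-channel gate, broadcast over H x W
assert out.shape == x.shape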
@@ -112,7 +109,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None,
             inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
             dilation=dilation, bias=False)
         self.bn1 = norm_layer(first_planes)
-        self.relu = nn.ReLU()
+        self.relu = nn.ReLU(inplace=True)
         self.conv2 = nn.Conv2d(
             first_planes, outplanes, kernel_size=3, padding=previous_dilation,
             dilation=previous_dilation, bias=False)
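A note on the inplace change (here and in the Bottleneck and stem hunks below): inplace=True lets the ReLU overwrite the batch-norm output instead of allocating a fresh activation tensor, which is safe in these conv-bn-act sequences and saves activation memory. A quick check of the aliasing behavior:

import torch
import torch.nn as nn

x = torch.randn(2, 4)
y = nn.ReLU(inplace=True)(x)
# The output shares the input's storage; no new activation tensor was allocated.
assert y.data_ptr() == x.data_ptr()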
@@ -164,7 +161,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None,
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
         self.se = SEModule(outplanes, planes // 4) if use_se else None
-        self.relu = nn.ReLU()
+        self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
         self.stride = stride
         self.dilation = dilation
@@ -203,13 +200,14 @@ class ResNet(nn.Module):
       * have conv-bn-act ordering
 
     This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
-    variants included in the MXNet Gluon ResNetV1b model
+    variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
+    'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
 
     ResNet variants:
-      * normal - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
+      * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
       * c - 3 layer deep 3x3 stem, stem_width = 32
       * d - 3 layer deep 3x3 stem, stem_width = 32, average pool in downsample
-      * e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample *no pretrained weights available
+      * e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample
       * s - 3 layer deep 3x3 stem, stem_width = 64
 
     ResNeXt
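To make the docstring concrete, here is a rough side-by-side of the two stem shapes it describes (the actual construction is in the next hunk); stem_width and the final channel count are illustrative:

import torch
import torch.nn as nn

in_chans, stem_width = 3, 32   # 'c'/'d' use stem_width = 32

# 'normal'/'b': a single 7x7 stride-2 conv straight to 64 channels
stem_b = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)

# 'c'/'d'/'e'/'s': three stacked 3x3 convs, only the first at stride 2
stem_c = nn.Sequential(
    nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(stem_width),
    nn.ReLU(inplace=True),
    nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(stem_width),
    nn.ReLU(inplace=True),
    nn.Conv2d(stem_width, stem_width * 2, 3, stride=1, padding=1, bias=False))

x = torch.randn(1, in_chans, 224, 224)
# Both stems halve the spatial resolution before the stride-2 max pool.
assert stem_b(x).shape[-2:] == stem_c(x).shape[-2:] == (112, 112)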
@@ -275,31 +273,25 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
             self.conv1 = nn.Sequential(*[
                 nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False),
                 norm_layer(stem_width),
-                nn.ReLU(),
+                nn.ReLU(inplace=True),
                 nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False),
                 norm_layer(stem_width),
-                nn.ReLU(),
+                nn.ReLU(inplace=True),
                 nn.Conv2d(stem_width, self.inplanes, 3, stride=1, padding=1, bias=False)])
         else:
             self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = norm_layer(self.inplanes)
-        self.relu = nn.ReLU()
+        self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         stride_3_4 = 1 if self.dilated else 2
         dilation_3 = 2 if self.dilated else 1
         dilation_4 = 4 if self.dilated else 1
-        self.layer1 = self._make_layer(
-            block, 64, layers[0], stride=1, reduce_first=block_reduce_first,
-            use_se=use_se, avg_down=avg_down, down_kernel_size=1, norm_layer=norm_layer)
-        self.layer2 = self._make_layer(
-            block, 128, layers[1], stride=2, reduce_first=block_reduce_first,
-            use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
-        self.layer3 = self._make_layer(
-            block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, reduce_first=block_reduce_first,
-            use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
-        self.layer4 = self._make_layer(
-            block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, reduce_first=block_reduce_first,
-            use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
+        largs = dict(use_se=use_se, reduce_first=block_reduce_first, norm_layer=norm_layer,
+                     avg_down=avg_down, down_kernel_size=down_kernel_size)
+        self.layer1 = self._make_layer(block, 64, layers[0], stride=1, **largs)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **largs)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, **largs)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, **largs)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_features = 512 * block.expansion
         self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
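The layer1-layer4 consolidation is behavior-preserving except for one detail: the old code hard-coded down_kernel_size=1 for layer1, while the shared largs dict now forwards the constructor's value to every stage; the guard added to _make_layer in the next hunk restores the old layer1 behavior. The **largs pattern itself is plain keyword expansion — a toy stand-in with made-up names, not timm's actual _make_layer:

# Toy stand-in for _make_layer; names and values are illustrative only.
def make_layer(planes, stride=1, dilation=1, use_se=False, reduce_first=1,
               avg_down=False, down_kernel_size=1, norm_layer=None):
    return (planes, stride, dilation, use_se, reduce_first,
            avg_down, down_kernel_size, norm_layer)

largs = dict(use_se=True, reduce_first=1, norm_layer=None,
             avg_down=True, down_kernel_size=3)
# **largs is exactly equivalent to spelling every keyword at each call site.
assert make_layer(64, stride=1, **largs) == make_layer(
    64, stride=1, use_se=True, reduce_first=1, norm_layer=None,
    avg_down=True, down_kernel_size=3)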
@@ -314,6 +306,7 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
                     use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d):
         downsample = None
+        down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample_padding = _get_padding(down_kernel_size, stride)
             downsample_layers = []
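This added line is what keeps the largs refactor above equivalent to the old per-layer calls: a stage that neither strides nor dilates (layer1) never needs anything larger than a 1x1 projection in its downsample path, so the passed-in down_kernel_size is reset there. A minimal check of the condition, with illustrative values:

# Guard from _make_layer, extracted for illustration.
def effective_down_kernel_size(stride, dilation, down_kernel_size):
    return 1 if stride == 1 and dilation == 1 else down_kernel_size

assert effective_down_kernel_size(1, 1, 3) == 1   # layer1: reset to a 1x1 projection
assert effective_down_kernel_size(2, 1, 3) == 3   # strided stages keep the passed value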
