Commit 110a7c4

AdaptiveAvgPool2d -> mean((2,3)) for all SE/attn layers to avoid NaN with AMP + channels_last layout. See pytorch/pytorch#43992
1 parent c2cd1a3 commit 110a7c4

8 files changed: +33 -66 lines changed
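
For context: the two poolings compute the same values; what changes is which kernel runs. A minimal sketch (not part of the commit) checking the equivalence on a channels_last tensor:

import torch
import torch.nn as nn

# Global average pooling two ways on a channels_last tensor.
x = torch.randn(2, 8, 7, 7).to(memory_format=torch.channels_last)

y_pool = nn.AdaptiveAvgPool2d(1)(x)     # (2, 8, 1, 1)
y_mean = x.mean((2, 3), keepdim=True)   # (2, 8, 1, 1), same values
assert torch.allclose(y_pool, y_mean, atol=1e-6)

Per the commit message, the explicit mean reduction sidesteps the AdaptiveAvgPool2d backend path that produced NaNs under AMP with channels_last inputs (pytorch/pytorch#43992).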

timm/models/efficientnet_blocks.py

Lines changed: 3 additions & 5 deletions

@@ -106,20 +106,18 @@ class SqueezeExcite(nn.Module):
     def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
                  act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
         super(SqueezeExcite, self).__init__()
-        self.gate_fn = gate_fn
         reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
         self.act1 = act_layer(inplace=True)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
+        self.gate_fn = gate_fn

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.conv_reduce(x_se)
         x_se = self.act1(x_se)
         x_se = self.conv_expand(x_se)
-        x = x * self.gate_fn(x_se)
-        return x
+        return x * self.gate_fn(x_se)


 class ConvBnAct(nn.Module):

timm/models/layers/cbam.py

Lines changed: 4 additions & 5 deletions

@@ -10,6 +10,7 @@

 import torch
 from torch import nn as nn
+import torch.nn.functional as F
 from .conv_bn_act import ConvBnAct


@@ -18,15 +19,13 @@ class ChannelAttn(nn.Module):
     """
     def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
         super(ChannelAttn, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.max_pool = nn.AdaptiveMaxPool2d(1)
         self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
         self.act = act_layer(inplace=True)
         self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)

     def forward(self, x):
-        x_avg = self.avg_pool(x)
-        x_max = self.max_pool(x)
+        x_avg = x.mean((2, 3), keepdim=True)
+        x_max = F.adaptive_max_pool2d(x, 1)
         x_avg = self.fc2(self.act(self.fc1(x_avg)))
         x_max = self.fc2(self.act(self.fc1(x_max)))
         x_attn = x_avg + x_max
@@ -40,7 +39,7 @@ def __init__(self, channels, reduction=16):
         super(LightChannelAttn, self).__init__(channels, reduction)

     def forward(self, x):
-        x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x)
+        x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1)
         x_attn = self.fc2(self.act(self.fc1(x_pool)))
         return x * x_attn.sigmoid()
timm/models/layers/eca.py

Lines changed: 6 additions & 23 deletions

@@ -52,22 +52,15 @@ class EcaModule(nn.Module):
     def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(EcaModule, self).__init__()
         assert kernel_size % 2 == 1
-
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)

-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)

     def forward(self, x):
-        # Feature descriptor on the global spatial information
-        y = self.avg_pool(x)
-        # Reshape for convolution
-        y = y.view(x.shape[0], 1, -1)
-        # Two different branches of ECA module
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)  # view for 1d conv
         y = self.conv(y)
-        # Multi-scale information fusion
         y = y.view(x.shape[0], -1, 1, 1).sigmoid()
         return x * y.expand_as(x)

@@ -95,30 +88,20 @@ class CecaModule(nn.Module):
     def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(CecaModule, self).__init__()
         assert kernel_size % 2 == 1
-
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)

-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        #pytorch circular padding mode is buggy as of pytorch 1.4
-        #see https://github.com/pytorch/pytorch/pull/17240
-
-        #implement manual circular padding
+        # PyTorch circular padding mode is buggy as of pytorch 1.4
+        # see https://github.com/pytorch/pytorch/pull/17240
+        # implement manual circular padding
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
         self.padding = (kernel_size - 1) // 2

     def forward(self, x):
-        # Feature descriptor on the global spatial information
-        y = self.avg_pool(x)
-
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)
         # Manually implement circular padding, F.pad does not seemed to be bugged
-        y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
-
-        # Two different branches of ECA module
+        y = F.pad(y, (self.padding, self.padding), mode='circular')
         y = self.conv(y)
-
-        # Multi-scale information fusion
         y = y.view(x.shape[0], -1, 1, 1).sigmoid()
-
         return x * y.expand_as(x)
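
A brief shape walkthrough (a sketch, not repo code) of the pooled path in EcaModule.forward, assuming a 3-tap kernel, since the view-for-1d-conv step is easy to misread:

import torch
import torch.nn as nn

B, C, H, W = 2, 16, 7, 7
x = torch.randn(B, C, H, W)

y = x.mean((2, 3))                 # (B, C): one descriptor per channel
y = y.view(B, 1, -1)               # (B, 1, C): channels become the 1d sequence
conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
y = conv(y)                        # (B, 1, C): local cross-channel interaction
y = y.view(B, -1, 1, 1).sigmoid()  # (B, C, 1, 1): per-channel gate
out = x * y.expand_as(x)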

timm/models/layers/se.py

Lines changed: 11 additions & 15 deletions

@@ -1,40 +1,36 @@
 from torch import nn as nn
-from .create_act import get_act_fn
+from .create_act import create_act_layer


 class SEModule(nn.Module):

     def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
-                 gate_fn='sigmoid'):
+                 gate_layer='sigmoid'):
         super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         reduction_channels = reduction_channels or max(channels // reduction, min_channels)
-        self.fc1 = nn.Conv2d(
-            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.act = act_layer(inplace=True)
-        self.fc2 = nn.Conv2d(
-            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
+        self.gate = create_act_layer(gate_layer)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc1(x_se)
         x_se = self.act(x_se)
         x_se = self.fc2(x_se)
-        return x * self.gate_fn(x_se)
+        return x * self.gate(x_se)


 class EffectiveSEModule(nn.Module):
     """ 'Effective Squeeze-Excitation
     From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
     """
-    def __init__(self, channels, gate_fn='hard_sigmoid'):
+    def __init__(self, channels, gate_layer='hard_sigmoid'):
         super(EffectiveSEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.gate = create_act_layer(gate_layer, inplace=True)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc(x_se)
-        return x * self.gate_fn(x_se, inplace=True)
+        return x * self.gate(x_se)
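
Besides the pooling change, se.py swaps the functional gate (get_act_fn) for a module built by create_act_layer, so the gate appears in the module tree and carries its inplace flag itself. A rough sketch of the distinction, with plain torch standing in for timm's factory results (assumed resolution for 'sigmoid'):

import torch
import torch.nn as nn

gate_fn = torch.sigmoid  # before: a bare function, invisible to named_modules()
gate = nn.Sigmoid()      # after: a submodule, as create_act_layer('sigmoid') builds

x = torch.randn(2, 8, 1, 1)
assert torch.allclose(gate_fn(x), gate(x))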

timm/models/layers/selective_kernel.py

Lines changed: 1 addition & 3 deletions

@@ -27,16 +27,14 @@ def __init__(self, channels, num_paths=2, attn_channels=32,
         """
         super(SelectiveKernelAttn, self).__init__()
         self.num_paths = num_paths
-        self.pool = nn.AdaptiveAvgPool2d(1)
         self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
         self.bn = norm_layer(attn_channels)
         self.act = act_layer(inplace=True)
         self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)

     def forward(self, x):
         assert x.shape[1] == self.num_paths
-        x = torch.sum(x, dim=1)
-        x = self.pool(x)
+        x = x.sum(1).mean((2, 3), keepdim=True)
         x = self.fc_reduce(x)
         x = self.bn(x)
         x = self.act(x)
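
The fused x.sum(1).mean((2, 3), keepdim=True) relies on the (B, num_paths, C, H, W) stacking implied by the assert; a small shape check (illustrative only):

import torch

B, P, C, H, W = 2, 2, 16, 7, 7
x = torch.randn(B, P, C, H, W)                 # stacked per-path features

attn_in = x.sum(1).mean((2, 3), keepdim=True)  # sum paths, then global pool
assert attn_in.shape == (B, C, 1, 1)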

timm/models/rexnet.py

Lines changed: 3 additions & 6 deletions

@@ -59,18 +59,15 @@ class SEWithNorm(nn.Module):
     def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                  gate_layer='sigmoid'):
         super(SEWithNorm, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor)
-        self.fc1 = nn.Conv2d(
-            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.bn = nn.BatchNorm2d(reduction_channels)
         self.act = act_layer(inplace=True)
-        self.fc2 = nn.Conv2d(
-            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
+        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
         self.gate = create_act_layer(gate_layer)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc1(x_se)
         x_se = self.bn(x_se)
         x_se = self.act(x_se)

timm/models/senet.py

Lines changed: 3 additions & 6 deletions

@@ -71,17 +71,14 @@ class SEModule(nn.Module):

     def __init__(self, channels, reduction):
         super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.fc1 = nn.Conv2d(
-            channels, channels // reduction, kernel_size=1, padding=0)
+        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
         self.relu = nn.ReLU(inplace=True)
-        self.fc2 = nn.Conv2d(
-            channels // reduction, channels, kernel_size=1, padding=0)
+        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)
         self.sigmoid = nn.Sigmoid()

     def forward(self, x):
         module_input = x
-        x = self.avg_pool(x)
+        x = x.mean((2, 3), keepdim=True)
         x = self.fc1(x)
         x = self.relu(x)
         x = self.fc2(x)

timm/models/tresnet.py

Lines changed: 2 additions & 3 deletions

@@ -56,10 +56,9 @@ def __init__(self, flatten=False):

     def forward(self, x):
         if self.flatten:
-            in_size = x.size()
-            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
+            return x.mean((2, 3))
         else:
-            return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
+            return x.mean((2, 3), keepdim=True)

     def feat_mult(self):
         return 1
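
The old view/mean round-trip and the direct spatial mean agree exactly; a quick sanity sketch (not part of the commit) for the flatten=False branch:

import torch

x = torch.randn(2, 8, 7, 7)

old = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
new = x.mean((2, 3), keepdim=True)
assert torch.allclose(old, new)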
