
Commit 6e7b0df

Merge pull request #7 from dpxudong/master
add se attention
2 parents f409af4 + c8ff39c

File tree


change_detection_pytorch/base/modules.py

Lines changed: 20 additions & 0 deletions
@@ -140,6 +140,24 @@ def forward(self, x):
         return out
 
 
+class SEModule(nn.Module):
+    def __init__(self, in_channels, reduction=16):
+        super(SEModule, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(in_channels, in_channels // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(in_channels // reduction, in_channels, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
 class ArgMax(nn.Module):
 
     def __init__(self, dim=None):
@@ -196,6 +214,8 @@ def __init__(self, name, **params):
             self.attention = CBAMSpatial(**params)
         elif name == 'cbam':
             self.attention = CBAM(**params)
+        elif name == 'se':
+            self.attention = SEModule(**params)
         else:
             raise ValueError("Attention {} is not implemented".format(name))
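For context: the new block is a standard squeeze-and-excitation (SE) channel-attention module in the style of Hu et al.'s Squeeze-and-Excitation Networks. Global average pooling squeezes each channel to a scalar, a two-layer bottleneck MLP ending in a sigmoid produces per-channel weights, and the input is rescaled by those weights, so the output shape matches the input. A minimal shape check as a sketch, assuming torch is available and the package is importable with this commit applied:

import torch
from change_detection_pytorch.base.modules import SEModule

se = SEModule(in_channels=64, reduction=16)  # bottleneck width: 64 // 16 = 4
x = torch.randn(2, 64, 32, 32)               # (batch, channels, height, width)
out = se(x)                                  # channels reweighted by factors in (0, 1)
assert out.shape == x.shape                  # spatial size and channel count unchanged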
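The second hunk registers the module under the name 'se' in the name-based attention factory, alongside 'cbam'. A hedged usage sketch follows; the wrapper class name Attention is an assumption inferred from the __init__(self, name, **params) signature in the hunk header, matching the segmentation_models_pytorch convention:

import torch
from change_detection_pytorch.base.modules import Attention  # class name assumed

att = Attention('se', in_channels=64)  # **params are forwarded to SEModule.__init__
y = att(torch.randn(1, 64, 16, 16))    # same shape out, channels reweighted

# Unknown names still fall through to the ValueError branch:
# Attention('foo')  ->  ValueError: Attention foo is not implemented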