Skip to content

Commit 5e6dbba

Browse files
committed
Add CBAM for experimentation
1 parent d725991 commit 5e6dbba

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

timm/models/layers/cbam.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
""" CBAM (sort-of) Attention
2+
3+
Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
4+
5+
Hacked together by Ross Wightman
6+
"""
7+
8+
import torch
9+
from torch import nn as nn
10+
from .conv_bn_act import ConvBnAct
11+
12+
13+
class ChannelAttn(nn.Module):
14+
""" Original CBAM channel attention module, currently avg + max pool variant only.
15+
"""
16+
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
17+
super(ChannelAttn, self).__init__()
18+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
19+
self.max_pool = nn.AdaptiveMaxPool2d(1)
20+
self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
21+
self.act = act_layer(inplace=True)
22+
self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
23+
24+
def forward(self, x):
25+
x_avg = self.avg_pool(x)
26+
x_max = self.max_pool(x)
27+
x_avg = self.fc2(self.act(self.fc1(x_avg)))
28+
x_max = self.fc2(self.act(self.fc1(x_max)))
29+
x_attn = x_avg + x_max
30+
return x * x_attn.sigmoid()
31+
32+
33+
class LightChannelAttn(ChannelAttn):
34+
"""An experimental 'lightweight' that sums avg + max pool first
35+
"""
36+
def __init__(self, channels, reduction=16):
37+
super(LightChannelAttn, self).__init__(channels, reduction)
38+
39+
def forward(self, x):
40+
x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x)
41+
x_attn = self.fc2(self.act(self.fc1(x_pool)))
42+
return x * x_attn.sigmoid()
43+
44+
45+
class SpatialAttn(nn.Module):
46+
""" Original CBAM spatial attention module
47+
"""
48+
def __init__(self, kernel_size=7):
49+
super(SpatialAttn, self).__init__()
50+
self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)
51+
52+
def forward(self, x):
53+
x_avg = torch.mean(x, dim=1, keepdim=True)
54+
x_max = torch.max(x, dim=1, keepdim=True)[0]
55+
x_attn = torch.cat([x_avg, x_max], dim=1)
56+
x_attn = self.conv(x_attn)
57+
return x * x_attn.sigmoid()
58+
59+
60+
class LightSpatialAttn(nn.Module):
61+
"""An experimental 'lightweight' variant that sums avg_pool and max_pool results.
62+
"""
63+
def __init__(self, kernel_size=7):
64+
super(LightSpatialAttn, self).__init__()
65+
self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)
66+
67+
def forward(self, x):
68+
x_avg = torch.mean(x, dim=1, keepdim=True)
69+
x_max = torch.max(x, dim=1, keepdim=True)[0]
70+
x_attn = 0.5 * x_avg + 0.5 * x_max
71+
x_attn = self.conv(x_attn)
72+
return x * x_attn.sigmoid()
73+
74+
75+
class CbamModule(nn.Module):
76+
def __init__(self, channels, spatial_kernel_size=7):
77+
super(CbamModule, self).__init__()
78+
self.channel = ChannelAttn(channels)
79+
self.spatial = SpatialAttn(spatial_kernel_size)
80+
81+
def forward(self, x):
82+
x = self.channel(x)
83+
x = self.spatial(x)
84+
return x
85+
86+
87+
class LightCbamModule(nn.Module):
88+
def __init__(self, channels, spatial_kernel_size=7):
89+
super(LightCbamModule, self).__init__()
90+
self.channel = LightChannelAttn(channels)
91+
self.spatial = LightSpatialAttn(spatial_kernel_size)
92+
93+
def forward(self, x):
94+
x = self.channel(x)
95+
x = self.spatial(x)
96+
return x
97+

timm/models/layers/create_attn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
from .se import SEModule
77
from .eca import EcaModule, CecaModule
8+
from .cbam import CbamModule, LightCbamModule
89

910

1011
def create_attn(attn_type, channels, **kwargs):
@@ -18,6 +19,10 @@ def create_attn(attn_type, channels, **kwargs):
1819
module_cls = EcaModule
1920
elif attn_type == 'eca':
2021
module_cls = CecaModule
22+
elif attn_type == 'cbam':
23+
module_cls = CbamModule
24+
elif attn_type == 'lcbam':
25+
module_cls = LightCbamModule
2126
else:
2227
assert False, "Invalid attn module (%s)" % attn_type
2328
elif isinstance(attn_type, bool):

0 commit comments

Comments
 (0)