Skip to content

Commit 80c9d9c

Browse files
committed
Add 'fast' global pool option, remove redundant SEModule from tresnet, normal one is now 'fast'
1 parent 90a01f4 commit 80c9d9c

File tree

2 files changed

+21
-42
lines changed

2 files changed

+21
-42
lines changed

timm/models/layers/adaptive_avgmax_pool.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
4949
return x
5050

5151

52+
class FastAdaptiveAvgPool2d(nn.Module):
53+
def __init__(self, flatten=False):
54+
super(FastAdaptiveAvgPool2d, self).__init__()
55+
self.flatten = flatten
56+
57+
def forward(self, x):
58+
return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True)
59+
60+
5261
class AdaptiveAvgMaxPool2d(nn.Module):
5362
def __init__(self, output_size=1):
5463
super(AdaptiveAvgMaxPool2d, self).__init__()
@@ -70,12 +79,16 @@ def forward(self, x):
7079
class SelectAdaptivePool2d(nn.Module):
7180
"""Selectable global pooling layer with dynamic input kernel size
7281
"""
73-
def __init__(self, output_size=1, pool_type='avg', flatten=False):
82+
def __init__(self, output_size=1, pool_type='fast', flatten=False):
7483
super(SelectAdaptivePool2d, self).__init__()
7584
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
7685
self.flatten = flatten
7786
if pool_type == '':
7887
self.pool = nn.Identity() # pass through
88+
elif pool_type == 'fast':
89+
assert output_size == 1
90+
self.pool = FastAdaptiveAvgPool2d(self.flatten)
91+
self.flatten = False
7992
elif pool_type == 'avg':
8093
self.pool = nn.AdaptiveAvgPool2d(output_size)
8194
elif pool_type == 'avgmax':

timm/models/tresnet.py

Lines changed: 7 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch.nn.functional as F
1515

1616
from .helpers import build_model_with_cfg
17-
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead
17+
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule
1818
from .registry import register_model
1919

2020
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@@ -49,40 +49,6 @@ def _cfg(url='', **kwargs):
4949
}
5050

5151

52-
class FastGlobalAvgPool2d(nn.Module):
53-
def __init__(self, flatten=False):
54-
super(FastGlobalAvgPool2d, self).__init__()
55-
self.flatten = flatten
56-
57-
def forward(self, x):
58-
if self.flatten:
59-
return x.mean((2, 3))
60-
else:
61-
return x.mean((2, 3), keepdim=True)
62-
63-
def feat_mult(self):
64-
return 1
65-
66-
67-
class FastSEModule(nn.Module):
68-
69-
def __init__(self, channels, reduction_channels, inplace=True):
70-
super(FastSEModule, self).__init__()
71-
self.avg_pool = FastGlobalAvgPool2d()
72-
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True)
73-
self.relu = nn.ReLU(inplace=inplace)
74-
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True)
75-
self.activation = nn.Sigmoid()
76-
77-
def forward(self, x):
78-
x_se = self.avg_pool(x)
79-
x_se2 = self.fc1(x_se)
80-
x_se2 = self.relu(x_se2)
81-
x_se = self.fc2(x_se2)
82-
x_se = self.activation(x_se)
83-
return x * x_se
84-
85-
8652
def IABN2Float(module: nn.Module) -> nn.Module:
8753
"""If `module` is IABN don't use half precision."""
8854
if isinstance(module, InplaceAbn):
@@ -119,8 +85,8 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_
11985
self.relu = nn.ReLU(inplace=True)
12086
self.downsample = downsample
12187
self.stride = stride
122-
reduce_layer_planes = max(planes * self.expansion // 4, 64)
123-
self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None
88+
reduction_chs = max(planes * self.expansion // 4, 64)
89+
self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None
12490

12591
def forward(self, x):
12692
if self.downsample is not None:
@@ -159,8 +125,8 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True,
159125
conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
160126
aa_layer(channels=planes, filt_size=3, stride=2))
161127

162-
reduce_layer_planes = max(planes * self.expansion // 8, 64)
163-
self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None
128+
reduction_chs = max(planes * self.expansion // 8, 64)
129+
self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None
164130

165131
self.conv3 = conv2d_iabn(
166132
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
@@ -189,7 +155,7 @@ def forward(self, x):
189155

190156
class TResNet(nn.Module):
191157
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
192-
global_pool='avg', drop_rate=0.):
158+
global_pool='fast', drop_rate=0.):
193159
self.num_classes = num_classes
194160
self.drop_rate = drop_rate
195161
super(TResNet, self).__init__()
@@ -272,7 +238,7 @@ def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=Non
272238
def get_classifier(self):
273239
return self.head.fc
274240

275-
def reset_classifier(self, num_classes, global_pool='avg'):
241+
def reset_classifier(self, num_classes, global_pool='fast'):
276242
self.head = ClassifierHead(
277243
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
278244

0 commit comments

Comments
 (0)