|
14 | 14 | import torch.nn.functional as F
|
15 | 15 |
|
16 | 16 | from .helpers import build_model_with_cfg
|
17 |
| -from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead |
| 17 | +from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule |
18 | 18 | from .registry import register_model
|
19 | 19 |
|
20 | 20 | __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
|
@@ -49,40 +49,6 @@ def _cfg(url='', **kwargs):
|
49 | 49 | }
|
50 | 50 |
|
51 | 51 |
|
52 |
| -class FastGlobalAvgPool2d(nn.Module): |
53 |
| - def __init__(self, flatten=False): |
54 |
| - super(FastGlobalAvgPool2d, self).__init__() |
55 |
| - self.flatten = flatten |
56 |
| - |
57 |
| - def forward(self, x): |
58 |
| - if self.flatten: |
59 |
| - return x.mean((2, 3)) |
60 |
| - else: |
61 |
| - return x.mean((2, 3), keepdim=True) |
62 |
| - |
63 |
| - def feat_mult(self): |
64 |
| - return 1 |
65 |
| - |
66 |
| - |
67 |
| -class FastSEModule(nn.Module): |
68 |
| - |
69 |
| - def __init__(self, channels, reduction_channels, inplace=True): |
70 |
| - super(FastSEModule, self).__init__() |
71 |
| - self.avg_pool = FastGlobalAvgPool2d() |
72 |
| - self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True) |
73 |
| - self.relu = nn.ReLU(inplace=inplace) |
74 |
| - self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True) |
75 |
| - self.activation = nn.Sigmoid() |
76 |
| - |
77 |
| - def forward(self, x): |
78 |
| - x_se = self.avg_pool(x) |
79 |
| - x_se2 = self.fc1(x_se) |
80 |
| - x_se2 = self.relu(x_se2) |
81 |
| - x_se = self.fc2(x_se2) |
82 |
| - x_se = self.activation(x_se) |
83 |
| - return x * x_se |
84 |
| - |
85 |
| - |
86 | 52 | def IABN2Float(module: nn.Module) -> nn.Module:
|
87 | 53 | """If `module` is IABN don't use half precision."""
|
88 | 54 | if isinstance(module, InplaceAbn):
|
@@ -119,8 +85,8 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_
|
119 | 85 | self.relu = nn.ReLU(inplace=True)
|
120 | 86 | self.downsample = downsample
|
121 | 87 | self.stride = stride
|
122 |
| - reduce_layer_planes = max(planes * self.expansion // 4, 64) |
123 |
| - self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None |
| 88 | + reduction_chs = max(planes * self.expansion // 4, 64) |
| 89 | + self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None |
124 | 90 |
|
125 | 91 | def forward(self, x):
|
126 | 92 | if self.downsample is not None:
|
@@ -159,8 +125,8 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True,
|
159 | 125 | conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
|
160 | 126 | aa_layer(channels=planes, filt_size=3, stride=2))
|
161 | 127 |
|
162 |
| - reduce_layer_planes = max(planes * self.expansion // 8, 64) |
163 |
| - self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None |
| 128 | + reduction_chs = max(planes * self.expansion // 8, 64) |
| 129 | + self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None |
164 | 130 |
|
165 | 131 | self.conv3 = conv2d_iabn(
|
166 | 132 | planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
|
@@ -189,7 +155,7 @@ def forward(self, x):
|
189 | 155 |
|
190 | 156 | class TResNet(nn.Module):
|
191 | 157 | def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
|
192 |
| - global_pool='avg', drop_rate=0.): |
| 158 | + global_pool='fast', drop_rate=0.): |
193 | 159 | self.num_classes = num_classes
|
194 | 160 | self.drop_rate = drop_rate
|
195 | 161 | super(TResNet, self).__init__()
|
@@ -272,7 +238,7 @@ def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=Non
|
272 | 238 | def get_classifier(self):
|
273 | 239 | return self.head.fc
|
274 | 240 |
|
275 |
| - def reset_classifier(self, num_classes, global_pool='avg'): |
| 241 | + def reset_classifier(self, num_classes, global_pool='fast'): |
276 | 242 | self.head = ClassifierHead(
|
277 | 243 | self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
278 | 244 |
|
|
0 commit comments