|
1 | 1 | import torch
|
2 | 2 | import torch.nn as nn
|
| 3 | +import torch.nn.functional as F |
| 4 | +import numpy as np |
3 | 5 |
|
4 | 6 | try:
|
5 | 7 | from inplace_abn import InPlaceABN
|
@@ -140,22 +142,218 @@ def forward(self, x):
|
140 | 142 | return out
|
141 | 143 |
|
142 | 144 |
|
143 |
| -class SEModule(nn.Module): |
144 |
| - def __init__(self, in_channels, reduction=16): |
145 |
| - super(SEModule, self).__init__() |
146 |
| - self.avg_pool = nn.AdaptiveAvgPool2d(1) |
147 |
| - self.fc = nn.Sequential( |
148 |
| - nn.Linear(in_channels, in_channels // reduction, bias=False), |
| 145 | +class ModuleHelper: |
| 146 | + |
| 147 | + @staticmethod |
| 148 | + def BNReLU(num_features, bn_type=None, **kwargs): |
| 149 | + return nn.Sequential( |
| 150 | + nn.BatchNorm2d(num_features, **kwargs), |
| 151 | + nn.ReLU() |
| 152 | + ) |
| 153 | + |
| 154 | + @staticmethod |
| 155 | + def BatchNorm2d(*args, **kwargs): |
| 156 | + return nn.BatchNorm2d |
| 157 | + |
| 158 | + |
| 159 | +class SpatialGather_Module(nn.Module): |
| 160 | + """ |
| 161 | + Aggregate the context features according to the initial |
| 162 | + predicted probability distribution. |
| 163 | + Employ the soft-weighted method to aggregate the context. |
| 164 | + """ |
| 165 | + def __init__(self, cls_num=0, scale=1): |
| 166 | + super(SpatialGather_Module, self).__init__() |
| 167 | + self.cls_num = cls_num |
| 168 | + self.scale = scale |
| 169 | + |
| 170 | + def forward(self, feats, probs): |
| 171 | + batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) |
| 172 | + probs = probs.view(batch_size, c, -1) |
| 173 | + feats = feats.view(batch_size, feats.size(1), -1) |
| 174 | + feats = feats.permute(0, 2, 1) # batch x hw x c |
| 175 | + probs = F.softmax(self.scale * probs, dim=2)# batch x k x hw |
| 176 | + ocr_context = torch.matmul(probs, feats)\ |
| 177 | + .permute(0, 2, 1).unsqueeze(3)# batch x k x c |
| 178 | + return ocr_context |
| 179 | + |
| 180 | + |
| 181 | +class _ObjectAttentionBlock(nn.Module): |
| 182 | + ''' |
| 183 | + The basic implementation for object context block |
| 184 | + Input: |
| 185 | + N X C X H X W |
| 186 | + Parameters: |
| 187 | + in_channels : the dimension of the input feature map |
| 188 | + key_channels : the dimension after the key/query transform |
| 189 | + scale : choose the scale to downsample the input feature maps (save memory cost) |
| 190 | + bn_type : specify the bn type |
| 191 | + Return: |
| 192 | + N X C X H X W |
| 193 | + ''' |
| 194 | + def __init__(self, |
| 195 | + in_channels, |
| 196 | + key_channels, |
| 197 | + scale=1, |
| 198 | + bn_type=None): |
| 199 | + super(_ObjectAttentionBlock, self).__init__() |
| 200 | + self.scale = scale |
| 201 | + self.in_channels = in_channels |
| 202 | + self.key_channels = key_channels |
| 203 | + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) |
| 204 | + self.f_pixel = nn.Sequential( |
| 205 | + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, |
| 206 | + kernel_size=1, stride=1, padding=0, bias=False), |
| 207 | + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), |
| 208 | + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, |
| 209 | + kernel_size=1, stride=1, padding=0, bias=False), |
| 210 | + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), |
| 211 | + ) |
| 212 | + self.f_object = nn.Sequential( |
| 213 | + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, |
| 214 | + kernel_size=1, stride=1, padding=0, bias=False), |
| 215 | + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), |
| 216 | + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, |
| 217 | + kernel_size=1, stride=1, padding=0, bias=False), |
| 218 | + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), |
| 219 | + ) |
| 220 | + self.f_down = nn.Sequential( |
| 221 | + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, |
| 222 | + kernel_size=1, stride=1, padding=0, bias=False), |
| 223 | + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), |
| 224 | + ) |
| 225 | + self.f_up = nn.Sequential( |
| 226 | + nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, |
| 227 | + kernel_size=1, stride=1, padding=0, bias=False), |
| 228 | + ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type), |
| 229 | + ) |
| 230 | + |
| 231 | + def forward(self, x, proxy): |
| 232 | + batch_size, h, w = x.size(0), x.size(2), x.size(3) |
| 233 | + if self.scale > 1: |
| 234 | + x = self.pool(x) |
| 235 | + |
| 236 | + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) |
| 237 | + query = query.permute(0, 2, 1) |
| 238 | + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) |
| 239 | + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) |
| 240 | + value = value.permute(0, 2, 1) |
| 241 | + |
| 242 | + sim_map = torch.matmul(query, key) |
| 243 | + sim_map = (self.key_channels**-.5) * sim_map |
| 244 | + sim_map = F.softmax(sim_map, dim=-1) |
| 245 | + |
| 246 | + # add bg context ... |
| 247 | + context = torch.matmul(sim_map, value) |
| 248 | + context = context.permute(0, 2, 1).contiguous() |
| 249 | + context = context.view(batch_size, self.key_channels, *x.size()[2:]) |
| 250 | + context = self.f_up(context) |
| 251 | + if self.scale > 1: |
| 252 | + context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True) |
| 253 | + |
| 254 | + return context |
| 255 | + |
| 256 | + |
| 257 | +class ObjectAttentionBlock2D(_ObjectAttentionBlock): |
| 258 | + def __init__(self, |
| 259 | + in_channels, |
| 260 | + key_channels, |
| 261 | + scale=1, |
| 262 | + bn_type=None): |
| 263 | + super(ObjectAttentionBlock2D, self).__init__(in_channels, |
| 264 | + key_channels, |
| 265 | + scale, |
| 266 | + bn_type=bn_type) |
| 267 | + |
| 268 | + |
| 269 | +class SpatialOCR_Module(nn.Module): |
| 270 | + """ |
| 271 | + Implementation of the OCR module: |
| 272 | + We aggregate the global object representation to update the representation for each pixel. |
| 273 | + """ |
| 274 | + def __init__(self, |
| 275 | + in_channels, |
| 276 | + key_channels, |
| 277 | + out_channels, |
| 278 | + scale=1, |
| 279 | + dropout=0.1, |
| 280 | + bn_type=None): |
| 281 | + super(SpatialOCR_Module, self).__init__() |
| 282 | + self.object_context_block = ObjectAttentionBlock2D(in_channels, |
| 283 | + key_channels, |
| 284 | + scale, |
| 285 | + bn_type) |
| 286 | + _in_channels = 2 * in_channels |
| 287 | + |
| 288 | + self.conv_bn_dropout = nn.Sequential( |
| 289 | + nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), |
| 290 | + ModuleHelper.BNReLU(out_channels, bn_type=bn_type), |
| 291 | + nn.Dropout2d(dropout) |
| 292 | + ) |
| 293 | + |
| 294 | + def forward(self, feats, proxy_feats): |
| 295 | + context = self.object_context_block(feats, proxy_feats) |
| 296 | + |
| 297 | + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) |
| 298 | + |
| 299 | + return output |
| 300 | + |
| 301 | + |
| 302 | +class OCR(nn.Module): |
| 303 | + """ |
| 304 | + Segmentation Transformer: Object-Contextual Representations for Semantic Segmentation |
| 305 | + https://arxiv.org/pdf/1909.11065.pdf |
| 306 | + """ |
| 307 | + def __init__(self, in_channels, num_classes, ocr_mid_channels=512, ocr_key_channels=256): |
| 308 | + |
| 309 | + super().__init__() |
| 310 | + pre_stage_channels = in_channels |
| 311 | + last_inp_channels = np.int(np.sum(pre_stage_channels)) |
| 312 | + |
| 313 | + self.conv3x3_ocr = nn.Sequential( |
| 314 | + nn.Conv2d(last_inp_channels, ocr_mid_channels, |
| 315 | + kernel_size=3, stride=1, padding=1), |
| 316 | + nn.BatchNorm2d(ocr_mid_channels), |
149 | 317 | nn.ReLU(inplace=True),
|
150 |
| - nn.Linear(in_channels // reduction, in_channels, bias=False), |
151 |
| - nn.Sigmoid() |
| 318 | + ) |
| 319 | + self.ocr_gather_head = SpatialGather_Module(num_classes) |
| 320 | + |
| 321 | + self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, |
| 322 | + key_channels=ocr_key_channels, |
| 323 | + out_channels=ocr_mid_channels, |
| 324 | + scale=1, |
| 325 | + dropout=0.05, |
| 326 | + ) |
| 327 | + self.cls_head = nn.Conv2d( |
| 328 | + ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True) |
| 329 | + |
| 330 | + self.aux_head = nn.Sequential( |
| 331 | + nn.Conv2d(last_inp_channels, last_inp_channels, |
| 332 | + kernel_size=1, stride=1, padding=0), |
| 333 | + nn.BatchNorm2d(last_inp_channels), |
| 334 | + nn.ReLU(inplace=True), |
| 335 | + nn.Conv2d(last_inp_channels, num_classes, |
| 336 | + kernel_size=1, stride=1, padding=0, bias=True) |
152 | 337 | )
|
153 | 338 |
|
154 | 339 | def forward(self, x):
|
155 |
| - b, c, _, _ = x.size() |
156 |
| - y = self.avg_pool(x).view(b, c) |
157 |
| - y = self.fc(y).view(b, c, 1, 1) |
158 |
| - return x * y.expand_as(x) |
| 340 | + |
| 341 | + out_aux_seg = [] |
| 342 | + |
| 343 | + # ocr |
| 344 | + out_aux = self.aux_head(x) |
| 345 | + # compute contrast feature |
| 346 | + feats = self.conv3x3_ocr(x) |
| 347 | + |
| 348 | + context = self.ocr_gather_head(feats, out_aux) |
| 349 | + feats = self.ocr_distri_head(feats, context) |
| 350 | + |
| 351 | + out = self.cls_head(feats) |
| 352 | + |
| 353 | + out_aux_seg.append(out_aux) |
| 354 | + out_aux_seg.append(out) |
| 355 | + |
| 356 | + return out_aux_seg |
159 | 357 |
|
160 | 358 |
|
161 | 359 | class ArgMax(nn.Module):
|
@@ -214,8 +412,6 @@ def __init__(self, name, **params):
|
214 | 412 | self.attention = CBAMSpatial(**params)
|
215 | 413 | elif name == 'cbam':
|
216 | 414 | self.attention = CBAM(**params)
|
217 |
| - elif name == 'se': |
218 |
| - self.attention = SEModule(**params) |
219 | 415 | else:
|
220 | 416 | raise ValueError("Attention {} is not implemented".format(name))
|
221 | 417 |
|
|
0 commit comments