
Commit 9286d9f

Merge branch 'main' of D:\github\change_detection.pytorch with conflicts.

1 parent cb88399
4 files changed: +229 -16 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 # Default ignored files
 __pycache__/
 .idea/
+.pytest_cache
 changelog.md
 test_tt.py
 change_detection_pytorch/datasets/PRCV_CD.py

change_detection_pytorch/base/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -9,4 +9,5 @@
 from .heads import (
     SegmentationHead,
     ClassificationHead,
-)
+    SegmentationOCRHead,
+)

change_detection_pytorch/base/heads.py

Lines changed: 16 additions & 1 deletion

@@ -1,5 +1,5 @@
 import torch.nn as nn
-from .modules import Flatten, Activation
+from .modules import Flatten, Activation, OCR
 
 
 class SegmentationHead(nn.Sequential):
@@ -11,6 +11,21 @@ def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, up
         super().__init__(conv2d, upsampling, activation)
 
 
+class SegmentationOCRHead(nn.Module):
+
+    def __init__(self, in_channels, out_channels, activation=None, upsampling=1, align_corners=True):
+        super().__init__()
+        self.ocr_head = OCR(in_channels, out_channels)
+        self.upsampling = nn.Upsample(scale_factor=upsampling, mode='bilinear', align_corners=align_corners) if upsampling > 1 else nn.Identity()
+        self.activation = Activation(activation)
+
+    def forward(self, x):
+        coarse_pre, pre = self.ocr_head(x)
+        coarse_pre = self.activation(self.upsampling(coarse_pre))
+        pre = self.activation(self.upsampling(pre))
+        return [coarse_pre, pre]
+
+
 class ClassificationHead(nn.Sequential):
 
     def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
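
For orientation, here is a minimal sketch of how the new SegmentationOCRHead could be exercised on its own. The batch size, channel count, spatial size, and upsampling factor below are illustrative assumptions, not values taken from this commit, and activation=None is assumed to resolve to an identity as in segmentation_models.pytorch.

import torch
from change_detection_pytorch.base import SegmentationOCRHead

# Hypothetical decoder output: batch of 2, 64 channels, 64x64 spatial size.
features = torch.randn(2, 64, 64, 64)

# The head wraps the new OCR module, so it returns two prediction maps:
# the coarse (auxiliary) one first, then the OCR-refined one.
head = SegmentationOCRHead(in_channels=64, out_channels=2, upsampling=4)
coarse_pre, pre = head(features)

print(coarse_pre.shape)  # torch.Size([2, 2, 256, 256])
print(pre.shape)         # torch.Size([2, 2, 256, 256])

With upsampling=4 the head takes the nn.Upsample branch; with the default upsampling=1 it falls back to nn.Identity and both maps stay at the input resolution.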

change_detection_pytorch/base/modules.py

Lines changed: 210 additions & 14 deletions

@@ -1,5 +1,7 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
 
 try:
     from inplace_abn import InPlaceABN
@@ -140,22 +142,218 @@ def forward(self, x):
         return out
 
 
-class SEModule(nn.Module):
-    def __init__(self, in_channels, reduction=16):
-        super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Sequential(
-            nn.Linear(in_channels, in_channels // reduction, bias=False),
+class ModuleHelper:
+
+    @staticmethod
+    def BNReLU(num_features, bn_type=None, **kwargs):
+        return nn.Sequential(
+            nn.BatchNorm2d(num_features, **kwargs),
+            nn.ReLU()
+        )
+
+    @staticmethod
+    def BatchNorm2d(*args, **kwargs):
+        return nn.BatchNorm2d
+
+
+class SpatialGather_Module(nn.Module):
+    """
+    Aggregate the context features according to the initial
+    predicted probability distribution.
+    Employ the soft-weighted method to aggregate the context.
+    """
+    def __init__(self, cls_num=0, scale=1):
+        super(SpatialGather_Module, self).__init__()
+        self.cls_num = cls_num
+        self.scale = scale
+
+    def forward(self, feats, probs):
+        batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
+        probs = probs.view(batch_size, c, -1)
+        feats = feats.view(batch_size, feats.size(1), -1)
+        feats = feats.permute(0, 2, 1)  # batch x hw x c
+        probs = F.softmax(self.scale * probs, dim=2)  # batch x k x hw
+        ocr_context = torch.matmul(probs, feats) \
+            .permute(0, 2, 1).unsqueeze(3)  # batch x c x k x 1
+        return ocr_context
+
+
+class _ObjectAttentionBlock(nn.Module):
+    '''
+    The basic implementation for object context block
+    Input:
+        N X C X H X W
+    Parameters:
+        in_channels : the dimension of the input feature map
+        key_channels : the dimension after the key/query transform
+        scale : choose the scale to downsample the input feature maps (save memory cost)
+        bn_type : specify the bn type
+    Return:
+        N X C X H X W
+    '''
+    def __init__(self,
+                 in_channels,
+                 key_channels,
+                 scale=1,
+                 bn_type=None):
+        super(_ObjectAttentionBlock, self).__init__()
+        self.scale = scale
+        self.in_channels = in_channels
+        self.key_channels = key_channels
+        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
+        self.f_pixel = nn.Sequential(
+            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
+                      kernel_size=1, stride=1, padding=0, bias=False),
+            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
+            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
+                      kernel_size=1, stride=1, padding=0, bias=False),
+            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
+        )
+        self.f_object = nn.Sequential(
+            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
+                      kernel_size=1, stride=1, padding=0, bias=False),
+            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
+            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
+                      kernel_size=1, stride=1, padding=0, bias=False),
+            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
+        )
+        self.f_down = nn.Sequential(
+            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
+                      kernel_size=1, stride=1, padding=0, bias=False),
+            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
+        )
+        self.f_up = nn.Sequential(
+            nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
+                      kernel_size=1, stride=1, padding=0, bias=False),
+            ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type),
+        )
+
+    def forward(self, x, proxy):
+        batch_size, h, w = x.size(0), x.size(2), x.size(3)
+        if self.scale > 1:
+            x = self.pool(x)
+
+        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
+        query = query.permute(0, 2, 1)
+        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
+        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
+        value = value.permute(0, 2, 1)
+
+        sim_map = torch.matmul(query, key)
+        sim_map = (self.key_channels ** -0.5) * sim_map
+        sim_map = F.softmax(sim_map, dim=-1)
+
+        # add bg context ...
+        context = torch.matmul(sim_map, value)
+        context = context.permute(0, 2, 1).contiguous()
+        context = context.view(batch_size, self.key_channels, *x.size()[2:])
+        context = self.f_up(context)
+        if self.scale > 1:
+            context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True)
+
+        return context
+
+
+class ObjectAttentionBlock2D(_ObjectAttentionBlock):
+    def __init__(self,
+                 in_channels,
+                 key_channels,
+                 scale=1,
+                 bn_type=None):
+        super(ObjectAttentionBlock2D, self).__init__(in_channels,
+                                                     key_channels,
+                                                     scale,
+                                                     bn_type=bn_type)
+
+
+class SpatialOCR_Module(nn.Module):
+    """
+    Implementation of the OCR module:
+    We aggregate the global object representation to update the representation for each pixel.
+    """
+    def __init__(self,
+                 in_channels,
+                 key_channels,
+                 out_channels,
+                 scale=1,
+                 dropout=0.1,
+                 bn_type=None):
+        super(SpatialOCR_Module, self).__init__()
+        self.object_context_block = ObjectAttentionBlock2D(in_channels,
+                                                           key_channels,
+                                                           scale,
+                                                           bn_type)
+        _in_channels = 2 * in_channels
+
+        self.conv_bn_dropout = nn.Sequential(
+            nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
+            ModuleHelper.BNReLU(out_channels, bn_type=bn_type),
+            nn.Dropout2d(dropout)
+        )
+
+    def forward(self, feats, proxy_feats):
+        context = self.object_context_block(feats, proxy_feats)
+
+        output = self.conv_bn_dropout(torch.cat([context, feats], 1))
+
+        return output
+
+
+class OCR(nn.Module):
+    """
+    Segmentation Transformer: Object-Contextual Representations for Semantic Segmentation
+    https://arxiv.org/pdf/1909.11065.pdf
+    """
+    def __init__(self, in_channels, num_classes, ocr_mid_channels=512, ocr_key_channels=256):
+
+        super().__init__()
+        pre_stage_channels = in_channels
+        last_inp_channels = int(np.sum(pre_stage_channels))
+
+        self.conv3x3_ocr = nn.Sequential(
+            nn.Conv2d(last_inp_channels, ocr_mid_channels,
+                      kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(ocr_mid_channels),
             nn.ReLU(inplace=True),
-            nn.Linear(in_channels // reduction, in_channels, bias=False),
-            nn.Sigmoid()
+        )
+        self.ocr_gather_head = SpatialGather_Module(num_classes)
+
+        self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
+                                                 key_channels=ocr_key_channels,
+                                                 out_channels=ocr_mid_channels,
+                                                 scale=1,
+                                                 dropout=0.05,
+                                                 )
+        self.cls_head = nn.Conv2d(
+            ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
+
+        self.aux_head = nn.Sequential(
+            nn.Conv2d(last_inp_channels, last_inp_channels,
+                      kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(last_inp_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(last_inp_channels, num_classes,
+                      kernel_size=1, stride=1, padding=0, bias=True)
         )
 
     def forward(self, x):
-        b, c, _, _ = x.size()
-        y = self.avg_pool(x).view(b, c)
-        y = self.fc(y).view(b, c, 1, 1)
-        return x * y.expand_as(x)
+
+        out_aux_seg = []
+
+        # ocr
+        out_aux = self.aux_head(x)
+        # compute contrast feature
+        feats = self.conv3x3_ocr(x)
+
+        context = self.ocr_gather_head(feats, out_aux)
+        feats = self.ocr_distri_head(feats, context)
+
+        out = self.cls_head(feats)
+
+        out_aux_seg.append(out_aux)
+        out_aux_seg.append(out)
+
+        return out_aux_seg
 
 
 class ArgMax(nn.Module):
@@ -214,8 +412,6 @@ def __init__(self, name, **params):
             self.attention = CBAMSpatial(**params)
         elif name == 'cbam':
             self.attention = CBAM(**params)
-        elif name == 'se':
-            self.attention = SEModule(**params)
         else:
             raise ValueError("Attention {} is not implemented".format(name))
 
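To see how the new modules fit together, the following sketch runs a tensor through OCR directly; all shapes are illustrative assumptions. aux_head first predicts coarse class maps from the raw features, SpatialGather_Module soft-weights the features by those maps to build one context vector per class, and SpatialOCR_Module attends from each pixel to the class-context vectors before the final 1x1 classifier.

import torch
from change_detection_pytorch.base.modules import OCR

# Hypothetical fused feature map: batch 2, 64 channels, 64x64 pixels.
x = torch.randn(2, 64, 64, 64)

ocr = OCR(in_channels=64, num_classes=2)
out_aux, out = ocr(x)  # forward returns [auxiliary logits, refined logits]

print(out_aux.shape)  # torch.Size([2, 2, 64, 64])
print(out.shape)      # torch.Size([2, 2, 64, 64])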