Skip to content

Commit fea645d

Browse files
committed
optimize fusion_form
1 parent e33a2c9 commit fea645d

File tree

10 files changed

+17
-10
lines changed

10 files changed

+17
-10
lines changed

change_detection_pytorch/base/decoder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33

44
class Decoder(torch.nn.Module):
5+
# TODO: support learnable fusion modules
6+
def __init__(self):
7+
super().__init__()
8+
self.FUSION_DIC = {"2to1_fusion": ["sum", "diff", "abs_diff"],
9+
"2to2_fusion": ["concat"]}
510

611
def fusion(self, x1, x2, fusion_form="concat"):
712
"""Specify the form of feature fusion"""
@@ -10,6 +15,8 @@ def fusion(self, x1, x2, fusion_form="concat"):
1015
elif fusion_form == "sum":
1116
x = x1 + x2
1217
elif fusion_form == "diff":
18+
x = x2 - x1
19+
elif fusion_form == "abs_diff":
1320
x = torch.abs(x1 - x2)
1421
else:
1522
raise ValueError('the fusion form "{}" is not defined'.format(fusion_form))

change_detection_pytorch/deeplabv3/decoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
class DeepLabV3Decoder(nn.Sequential, Decoder):
4343
def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36), fusion_form="concat"):
4444
# adjust encoder channels according to fusion form
45-
if fusion_form == "concat":
45+
if fusion_form in self.FUSION_DIC["2to2_fusion"]:
4646
in_channels = in_channels * 2
4747

4848
super().__init__(
@@ -77,7 +77,7 @@ def __init__(
7777

7878
# adjust encoder channels according to fusion form
7979
self.fusion_form = fusion_form
80-
if self.fusion_form == "concat":
80+
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
8181
encoder_channels = [ch*2 for ch in encoder_channels]
8282

8383
self.aspp = nn.Sequential(

change_detection_pytorch/fpn/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(
9797

9898
# adjust encoder channels according to fusion form
9999
self.fusion_form = fusion_form
100-
if self.fusion_form == "concat":
100+
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
101101
encoder_channels = [ch*2 for ch in encoder_channels]
102102

103103
self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)

change_detection_pytorch/linknet/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252

5353
# adjust encoder channels according to fusion form
5454
self.fusion_form = fusion_form
55-
if self.fusion_form == "concat":
55+
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
5656
encoder_channels = [ch*2 for ch in encoder_channels]
5757

5858
channels = list(encoder_channels) + [prefinal_channels]

change_detection_pytorch/manet/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def __init__(
166166

167167
# adjust encoder channels according to fusion form
168168
self.fusion_form = fusion_form
169-
if self.fusion_form == "concat":
169+
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
170170
skip_channels = [ch*2 for ch in skip_channels]
171171
in_channels[0] = in_channels[0] * 2
172172
head_channels = head_channels * 2

change_detection_pytorch/pan/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def __init__(
156156

157157
# adjust encoder channels according to fusion form
158158
self.fusion_form = fusion_form
159-
if self.fusion_form == "concat":
159+
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
160160
encoder_channels = [ch * 2 for ch in encoder_channels]
161161

162162
self.fpa = FPABlock(in_channels=encoder_channels[-1], out_channels=decoder_channels)

change_detection_pytorch/pspnet/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151

5252
# adjust encoder channels according to fusion form
5353
self.fusion_form = fusion_form
54-
if self.fusion_form == "concat":
54+
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
5555
encoder_channels = [ch*2 for ch in encoder_channels]
5656

5757
self.psp = PSPModule(

change_detection_pytorch/unet/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(
9696

9797
# adjust encoder channels according to fusion form
9898
self.fusion_form = fusion_form
99-
if self.fusion_form == "concat":
99+
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
100100
skip_channels = [ch*2 for ch in skip_channels]
101101
in_channels[0] = in_channels[0] * 2
102102
head_channels = head_channels * 2

change_detection_pytorch/unetplusplus/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494

9595
# adjust encoder channels according to fusion form
9696
self.fusion_form = fusion_form
97-
if self.fusion_form == "concat":
97+
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
9898
self.skip_channels = [ch*2 for ch in self.skip_channels]
9999
self.in_channels[0] = self.in_channels[0] * 2
100100

change_detection_pytorch/upernet/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(
9898

9999
# adjust encoder channels according to fusion form
100100
self.fusion_form = fusion_form
101-
if self.fusion_form == "concat":
101+
if self.fusion_form in self.FUSION_DIC["2to2_fusion"]:
102102
encoder_channels = [ch*2 for ch in encoder_channels]
103103

104104
self.psp = PSPModule(

0 commit comments

Comments
 (0)