Skip to content

Commit 65a3bfa

Browse files
committed
backbone: Use only one classifier and interpolation
1 parent 6c6e59d commit 65a3bfa

File tree

3 files changed

+19
-25
lines changed

3 files changed

+19
-25
lines changed

models/backbone.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,22 @@ def __init__(self, num_classes: int):
1919
self.layer4 = resnet34.layer4 # 512, 1/8
2020

2121
# Classifier
22-
self.classifier_2s = nn.Conv2d(128, num_classes, kernel_size=1)
23-
self.classifier_4s = nn.Conv2d(256, num_classes, kernel_size=1)
24-
self.classifier_8s = nn.Conv2d(512, num_classes, kernel_size=1)
22+
self.classifier = nn.Conv2d(512, num_classes, kernel_size=1)
2523

2624
def forward(self, x):
2725
# Encoder
28-
initial_conv = self.initial_conv(x)
29-
layer1 = self.layer1(initial_conv)
30-
layer2 = self.layer2(layer1)
31-
layer3 = self.layer3(layer2)
32-
layer4 = self.layer4(layer3)
26+
x = self.initial_conv(x)
27+
x = self.layer1(x)
28+
x = self.layer2(x)
29+
x = self.layer3(x)
30+
x = self.layer4(x)
3331

3432
# Classifier
35-
classifier_8s = self.classifier_8s(layer4)
36-
classifier_4s = self.classifier_4s(layer3)
37-
classifier_2s = self.classifier_2s(layer2)
33+
x = self.classifier(x)
3834

39-
# FCN
40-
classifier_4s += F.interpolate(classifier_8s, scale_factor=2, mode='bilinear', align_corners=False)
41-
classifier_2s += F.interpolate(classifier_4s, scale_factor=2, mode='bilinear', align_corners=False)
42-
out = F.interpolate(classifier_2s, scale_factor=2, mode='bilinear', align_corners=False)
43-
return out
35+
# Upsample
36+
x = F.interpolate(x, scale_factor=8, mode='bilinear', align_corners=False)
37+
return x
4438

4539
def make_initial_conv(self, in_channels: int, out_channels: int):
4640
return nn.Sequential(
@@ -65,11 +59,11 @@ def load_backbone(num_classes: int, pretrained=False):
6559

6660
if __name__ == '__main__':
6761
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
68-
model = Backbone(8).to(device)
62+
model = Backbone(20).to(device)
6963
model.eval()
7064

71-
torchsummary.torchsummary.summary(model, (3, 256, 512))
65+
torchsummary.torchsummary.summary(model, (3, 400, 800))
7266

7367
writer = torch.utils.tensorboard.SummaryWriter('../runs')
74-
writer.add_graph(model, torch.rand(1, 3, 256, 512).to(device))
68+
writer.add_graph(model, torch.rand(1, 3, 400, 800).to(device))
7569
writer.close()

models/proposed.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,11 @@ def make_channel_adjuster(self, in_channels: int, out_channels: int):
117117

118118
if __name__ == '__main__':
119119
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
120-
model = Proposed(8).to(device)
120+
model = Proposed(20).to(device)
121121
model.eval()
122122

123-
torchsummary.torchsummary.summary(model, (3, 256, 512))
123+
torchsummary.torchsummary.summary(model, (3, 400, 800))
124124

125125
writer = torch.utils.tensorboard.SummaryWriter('../runs')
126-
writer.add_graph(model, torch.rand(1, 3, 256, 512).to(device))
126+
writer.add_graph(model, torch.rand(1, 3, 400, 800).to(device))
127127
writer.close()

models/unet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ def double_conv(self, in_channels: int, out_channels: int):
6060

6161
if __name__ == '__main__':
6262
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
63-
model = UNet(8).to(device)
63+
model = UNet(20).to(device)
6464
model.eval()
6565

66-
torchsummary.torchsummary.summary(model, (3, 256, 512))
66+
torchsummary.torchsummary.summary(model, (3, 400, 800))
6767

6868
writer = torch.utils.tensorboard.SummaryWriter('../runs')
69-
writer.add_graph(model, torch.rand(1, 3, 256, 512).to(device))
69+
writer.add_graph(model, torch.rand(1, 3, 400, 800).to(device))
7070
writer.close()

0 commit comments

Comments
 (0)