@@ -19,28 +19,22 @@ def __init__(self, num_classes: int):
19
19
self .layer4 = resnet34 .layer4 # 512, 1/8
20
20
21
21
# Classifier
22
- self .classifier_2s = nn .Conv2d (128 , num_classes , kernel_size = 1 )
23
- self .classifier_4s = nn .Conv2d (256 , num_classes , kernel_size = 1 )
24
- self .classifier_8s = nn .Conv2d (512 , num_classes , kernel_size = 1 )
22
+ self .classifier = nn .Conv2d (512 , num_classes , kernel_size = 1 )
25
23
26
24
def forward (self , x ):
27
25
# Encoder
28
- initial_conv = self .initial_conv (x )
29
- layer1 = self .layer1 (initial_conv )
30
- layer2 = self .layer2 (layer1 )
31
- layer3 = self .layer3 (layer2 )
32
- layer4 = self .layer4 (layer3 )
26
+ x = self .initial_conv (x )
27
+ x = self .layer1 (x )
28
+ x = self .layer2 (x )
29
+ x = self .layer3 (x )
30
+ x = self .layer4 (x )
33
31
34
32
# Classifier
35
- classifier_8s = self .classifier_8s (layer4 )
36
- classifier_4s = self .classifier_4s (layer3 )
37
- classifier_2s = self .classifier_2s (layer2 )
33
+ x = self .classifier (x )
38
34
39
- # FCN
40
- classifier_4s += F .interpolate (classifier_8s , scale_factor = 2 , mode = 'bilinear' , align_corners = False )
41
- classifier_2s += F .interpolate (classifier_4s , scale_factor = 2 , mode = 'bilinear' , align_corners = False )
42
- out = F .interpolate (classifier_2s , scale_factor = 2 , mode = 'bilinear' , align_corners = False )
43
- return out
35
+ # Upsample
36
+ x = F .interpolate (x , scale_factor = 8 , mode = 'bilinear' , align_corners = False )
37
+ return x
44
38
45
39
def make_initial_conv (self , in_channels : int , out_channels : int ):
46
40
return nn .Sequential (
@@ -65,11 +59,11 @@ def load_backbone(num_classes: int, pretrained=False):
65
59
66
60
if __name__ == '__main__' :
67
61
device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
68
- model = Backbone (8 ).to (device )
62
+ model = Backbone (20 ).to (device )
69
63
model .eval ()
70
64
71
- torchsummary .torchsummary .summary (model , (3 , 256 , 512 ))
65
+ torchsummary .torchsummary .summary (model , (3 , 400 , 800 ))
72
66
73
67
writer = torch .utils .tensorboard .SummaryWriter ('../runs' )
74
- writer .add_graph (model , torch .rand (1 , 3 , 256 , 512 ).to (device ))
68
+ writer .add_graph (model , torch .rand (1 , 3 , 400 , 800 ).to (device ))
75
69
writer .close ()
0 commit comments