Skip to content

Commit 8a19ec7

Browse files
committed
models: Refactor to simplify code
1 parent 25a2423 commit 8a19ec7

File tree

1 file changed

+18
-22
lines changed

1 file changed

+18
-22
lines changed

models/backbone/xception.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4949
else:
5050
shortcut = x
5151

52-
out = self.relu1(x)
53-
out = self.sepconv1(out)
54-
out = self.relu2(out)
55-
out = self.sepconv2(out) # forward hook
56-
out = self.relu3(out)
57-
out = self.sepconv3(out)
52+
x = self.relu1(x)
53+
x = self.sepconv1(x)
54+
x = self.relu2(x)
55+
x = self.sepconv2(x) # forward hook
56+
x = self.relu3(x)
57+
x = self.sepconv3(x)
5858

59-
out += shortcut
60-
return out
59+
x += shortcut
60+
return x
6161

6262

6363
class Xception(nn.Module):
@@ -75,12 +75,16 @@ def __init__(self, output_stride: int):
7575
raise NotImplementedError('Wrong output_stride.')
7676

7777
# Entry flow
78-
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
79-
self.bn1 = nn.BatchNorm2d(32)
80-
self.relu1 = nn.ReLU(inplace=True)
81-
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
82-
self.bn2 = nn.BatchNorm2d(64)
83-
self.relu2 = nn.ReLU(inplace=True)
78+
self.conv1 = nn.Sequential(
79+
nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
80+
nn.BatchNorm2d(32),
81+
nn.ReLU(inplace=True)
82+
)
83+
self.conv2 = nn.Sequential(
84+
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
85+
nn.BatchNorm2d(64),
86+
nn.ReLU(inplace=True)
87+
)
8488
self.block1 = Block(64, 128, 2, dilation=1, skip_connection_type='conv')
8589
self.block2 = Block(128, 256, 2, dilation=1, skip_connection_type='conv')
8690
self.block3 = Block(256, 728, entry_block3_stride, dilation=1, skip_connection_type='conv')
@@ -108,20 +112,12 @@ def __init__(self, output_stride: int):
108112
def forward(self, x: torch.Tensor) -> torch.Tensor:
109113
# Entry flow
110114
x = self.conv1(x)
111-
x = self.bn1(x)
112-
x = self.relu1(x)
113115
x = self.conv2(x)
114-
x = self.bn2(x)
115-
x = self.relu2(x)
116116
x = self.block1(x)
117117
x = self.block2(x)
118-
self.low_level_feature.append(self.block2.hook_layer)
119118
x = self.block3(x)
120119

121-
# Middle flow
122120
x = self.middle_flow(x)
123-
124-
# Exit flow
125121
x = self.exit_flow(x)
126122
return x
127123

0 commit comments

Comments
 (0)