Skip to content

Commit 25a2423

Browse files
committed
models: Refactor to simplify code
1 parent 75e64e0 commit 25a2423

File tree

1 file changed

+6
-33
lines changed

1 file changed

+6
-33
lines changed

models/backbone/xception.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5252
out = self.relu1(x)
5353
out = self.sepconv1(out)
5454
out = self.relu2(out)
55-
out = self.sepconv2(out) # hook
55+
out = self.sepconv2(out) # forward hook
5656
out = self.relu3(out)
5757
out = self.sepconv3(out)
5858

@@ -86,22 +86,10 @@ def __init__(self, output_stride: int):
8686
self.block3 = Block(256, 728, entry_block3_stride, dilation=1, skip_connection_type='conv')
8787

8888
# Middle flow
89-
self.block4 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
90-
self.block5 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
91-
self.block6 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
92-
self.block7 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
93-
self.block8 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
94-
self.block9 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
95-
self.block10 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
96-
self.block11 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
97-
self.block12 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
98-
self.block13 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
99-
self.block14 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
100-
self.block15 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
101-
self.block16 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
102-
self.block17 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
103-
self.block18 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
104-
self.block19 = Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum')
89+
layers = []
90+
for _ in range(16):
91+
layers.append(Block(728, 728, 1, middle_block_dilation, skip_connection_type='sum'))
92+
self.middle_flow = nn.Sequential(*layers)
10593

10694
# Exit flow
10795
self.exit_flow = nn.Sequential(
@@ -131,22 +119,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
131119
x = self.block3(x)
132120

133121
# Middle flow
134-
x = self.block4(x)
135-
x = self.block5(x)
136-
x = self.block6(x)
137-
x = self.block7(x)
138-
x = self.block8(x)
139-
x = self.block9(x)
140-
x = self.block10(x)
141-
x = self.block11(x)
142-
x = self.block12(x)
143-
x = self.block13(x)
144-
x = self.block14(x)
145-
x = self.block15(x)
146-
x = self.block16(x)
147-
x = self.block17(x)
148-
x = self.block18(x)
149-
x = self.block19(x)
122+
x = self.middle_flow(x)
150123

151124
# Exit flow
152125
x = self.exit_flow(x)

0 commit comments

Comments
 (0)