@@ -52,7 +52,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
52
52
out = self .relu1 (x )
53
53
out = self .sepconv1 (out )
54
54
out = self .relu2 (out )
55
- out = self .sepconv2 (out ) # hook
55
+ out = self .sepconv2 (out ) # forward hook
56
56
out = self .relu3 (out )
57
57
out = self .sepconv3 (out )
58
58
@@ -86,22 +86,10 @@ def __init__(self, output_stride: int):
86
86
self .block3 = Block (256 , 728 , entry_block3_stride , dilation = 1 , skip_connection_type = 'conv' )
87
87
88
88
# Middle flow
89
- self .block4 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
90
- self .block5 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
91
- self .block6 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
92
- self .block7 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
93
- self .block8 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
94
- self .block9 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
95
- self .block10 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
96
- self .block11 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
97
- self .block12 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
98
- self .block13 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
99
- self .block14 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
100
- self .block15 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
101
- self .block16 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
102
- self .block17 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
103
- self .block18 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
104
- self .block19 = Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' )
89
+ layers = []
90
+ for _ in range (16 ):
91
+ layers .append (Block (728 , 728 , 1 , middle_block_dilation , skip_connection_type = 'sum' ))
92
+ self .middle_flow = nn .Sequential (* layers )
105
93
106
94
# Exit flow
107
95
self .exit_flow = nn .Sequential (
@@ -131,22 +119,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
131
119
x = self .block3 (x )
132
120
133
121
# Middle flow
134
- x = self .block4 (x )
135
- x = self .block5 (x )
136
- x = self .block6 (x )
137
- x = self .block7 (x )
138
- x = self .block8 (x )
139
- x = self .block9 (x )
140
- x = self .block10 (x )
141
- x = self .block11 (x )
142
- x = self .block12 (x )
143
- x = self .block13 (x )
144
- x = self .block14 (x )
145
- x = self .block15 (x )
146
- x = self .block16 (x )
147
- x = self .block17 (x )
148
- x = self .block18 (x )
149
- x = self .block19 (x )
122
+ x = self .middle_flow (x )
150
123
151
124
# Exit flow
152
125
x = self .exit_flow (x )
0 commit comments