@@ -49,15 +49,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
49
49
else :
50
50
shortcut = x
51
51
52
- out = self .relu1 (x )
53
- out = self .sepconv1 (out )
54
- out = self .relu2 (out )
55
- out = self .sepconv2 (out ) # forward hook
56
- out = self .relu3 (out )
57
- out = self .sepconv3 (out )
52
+ x = self .relu1 (x )
53
+ x = self .sepconv1 (x )
54
+ x = self .relu2 (x )
55
+ x = self .sepconv2 (x ) # forward hook
56
+ x = self .relu3 (x )
57
+ x = self .sepconv3 (x )
58
58
59
- out += shortcut
60
- return out
59
+ x += shortcut
60
+ return x
61
61
62
62
63
63
class Xception (nn .Module ):
@@ -75,12 +75,16 @@ def __init__(self, output_stride: int):
75
75
raise NotImplementedError ('Wrong output_stride.' )
76
76
77
77
# Entry flow
78
- self .conv1 = nn .Conv2d (3 , 32 , kernel_size = 3 , stride = 2 , padding = 1 , bias = False )
79
- self .bn1 = nn .BatchNorm2d (32 )
80
- self .relu1 = nn .ReLU (inplace = True )
81
- self .conv2 = nn .Conv2d (32 , 64 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False )
82
- self .bn2 = nn .BatchNorm2d (64 )
83
- self .relu2 = nn .ReLU (inplace = True )
78
+ self .conv1 = nn .Sequential (
79
+ nn .Conv2d (3 , 32 , kernel_size = 3 , stride = 2 , padding = 1 , bias = False ),
80
+ nn .BatchNorm2d (32 ),
81
+ nn .ReLU (inplace = True )
82
+ )
83
+ self .conv2 = nn .Sequential (
84
+ nn .Conv2d (32 , 64 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
85
+ nn .BatchNorm2d (64 ),
86
+ nn .ReLU (inplace = True )
87
+ )
84
88
self .block1 = Block (64 , 128 , 2 , dilation = 1 , skip_connection_type = 'conv' )
85
89
self .block2 = Block (128 , 256 , 2 , dilation = 1 , skip_connection_type = 'conv' )
86
90
self .block3 = Block (256 , 728 , entry_block3_stride , dilation = 1 , skip_connection_type = 'conv' )
@@ -108,20 +112,12 @@ def __init__(self, output_stride: int):
108
112
def forward (self , x : torch .Tensor ) -> torch .Tensor :
109
113
# Entry flow
110
114
x = self .conv1 (x )
111
- x = self .bn1 (x )
112
- x = self .relu1 (x )
113
115
x = self .conv2 (x )
114
- x = self .bn2 (x )
115
- x = self .relu2 (x )
116
116
x = self .block1 (x )
117
117
x = self .block2 (x )
118
- self .low_level_feature .append (self .block2 .hook_layer )
119
118
x = self .block3 (x )
120
119
121
- # Middle flow
122
120
x = self .middle_flow (x )
123
-
124
- # Exit flow
125
121
x = self .exit_flow (x )
126
122
return x
127
123
0 commit comments