1
- # returns `DropBlock`s for each block of the ResNet
2
- function _drop_blocks (drop_block_prob = 0.0 )
3
- return [
4
- identity,
5
- identity,
6
- DropBlock (drop_block_prob, 5 , 0.25 ),
7
- DropBlock (drop_block_prob, 3 , 1.00 ),
8
- ]
9
- end
10
-
11
1
function downsample_conv (kernel_size, inplanes, outplanes; stride = 1 , dilation = 1 ,
12
2
first_dilation = nothing , norm_layer = BatchNorm)
13
3
kernel_size = stride == 1 && dilation == 1 ? (1 , 1 ) : kernel_size
@@ -80,12 +70,15 @@ expansion_factor(::typeof(bottleneck)) = 4
80
70
81
71
function resnet_stem (; stem_type = :default , inchannels = 3 , replace_stem_pool = false ,
82
72
norm_layer = BatchNorm, activation = relu)
83
- @assert stem_type in [:default , :deep , :deep_tiered ] " Stem type must be one of [:default, :deep, :deep_tiered]"
73
+ @assert stem_type in [:default , :deep , :deep_tiered ]
74
+ " Stem type must be one of [:default, :deep, :deep_tiered]"
84
75
# Main stem
85
- inplanes = stem_type == :deep ? stem_width * 2 : 64
86
- if stem_type == :deep
87
- stem_channels = (stem_width, stem_width)
88
- if stem_type == :deep_tiered
76
+ deep_stem = stem_type == :deep || stem_type == :deep_tiered
77
+ inplanes = deep_stem ? stem_width * 2 : 64
78
+ if deep_stem
79
+ if stem_type == :deep
80
+ stem_channels = (stem_width, stem_width)
81
+ elseif stem_type == :deep_tiered
89
82
stem_channels = (3 * (stem_width ÷ 4 ), stem_width)
90
83
end
91
84
conv1 = Chain (Conv ((3 , 3 ), inchannels => stem_channels[0 ]; stride = 2 , pad = 1 ,
@@ -107,7 +100,7 @@ function resnet_stem(; stem_type = :default, inchannels = 3, replace_stem_pool =
107
100
else
108
101
stempool = MaxPool ((3 , 3 ); stride = 2 , pad = 1 )
109
102
end
110
- return inplanes, Chain (conv1, bn1, stempool)
103
+ return Chain (conv1, bn1, stempool), inplanes
111
104
end
112
105
113
106
function downsample_block (downsample_fn, inplanes, planes, expansion; kernel_size = (1 , 1 ),
128
121
# See `basicblock` and `bottleneck` for examples. A block must define a function
129
122
# `expansion(::typeof(block))` that returns the expansion factor of the block.
130
123
function _make_blocks (block_fn, channels, block_repeats, inplanes; output_stride = 32 ,
131
- downsample_fn = downsample_conv, downsample_args:: NamedTuple = (),
132
- drop_block_rate = 0.0 , drop_path_rate = 0.0 ,
133
- block_args:: NamedTuple = ())
124
+ downsample_fn = downsample_conv,
125
+ drop_rates:: NamedTuple , block_args:: NamedTuple )
134
126
@assert output_stride in (8 , 16 , 32 ) " Invalid `output_stride`. Must be one of (8, 16, 32)"
135
127
expansion = expansion_factor (block_fn)
136
128
stages = []
@@ -139,7 +131,7 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride
139
131
dilation = prev_dilation = 1
140
132
for (stage_idx, (planes, num_blocks, drop_block)) in enumerate (zip (channels,
141
133
block_repeats,
142
- _drop_blocks (drop_block_rate)))
134
+ _drop_blocks (drop_rates . drop_block_rate)))
143
135
# Stride calculations for each stage
144
136
stride = stage_idx == 1 ? 1 : 2
145
137
if net_stride >= output_stride
@@ -148,16 +140,16 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride
148
140
else
149
141
net_stride *= stride
150
142
end
151
- # Downsample block; either a (default) convolution-based block or a pooling-based block.
143
+ # Downsample block; either a (default) convolution-based block or a pooling-based block
152
144
downsample = downsample_block (downsample_fn, inplanes, planes, expansion;
153
- stride, dilation, first_dilation = dilation, downsample_args ... )
145
+ stride, dilation, first_dilation = dilation)
154
146
# Construct the blocks for each stage
155
147
blocks = []
156
148
for block_idx in 1 : num_blocks
157
149
downsample = block_idx == 1 ? downsample : identity
158
150
stride = block_idx == 1 ? stride : 1
159
151
# stochastic depth linear decay rule
160
- block_dpr = drop_path_rate * net_block_idx / (sum (block_repeats) - 1 )
152
+ block_dpr = drop_rates . drop_path_rate * net_block_idx / (sum (block_repeats) - 1 )
161
153
push! (blocks,
162
154
block_fn (inplanes, planes; stride, downsample,
163
155
first_dilation = prev_dilation,
@@ -171,21 +163,26 @@ function _make_blocks(block_fn, channels, block_repeats, inplanes; output_stride
171
163
return Chain (stages... )
172
164
end
173
165
166
+ # returns `DropBlock`s for each block of the ResNet
167
+ function _drop_blocks (drop_block_prob = 0.0 )
168
+ return [
169
+ identity,
170
+ identity,
171
+ DropBlock (drop_block_prob, 5 , 0.25 ),
172
+ DropBlock (drop_block_prob, 3 , 1.00 ),
173
+ ]
174
+ end
175
+
174
176
function resnet (block, layers; nclasses = 1000 , inchannels = 3 , output_stride = 32 ,
175
- stem_fn = resnet_stem, stem_args :: NamedTuple = NamedTuple () ,
176
- downsample_fn = downsample_conv, downsample_args :: NamedTuple = NamedTuple (),
177
+ stem = first ( resnet_stem (; inchannels)), inplanes = 64 ,
178
+ downsample_fn = downsample_conv,
177
179
drop_rates:: NamedTuple = (drop_rate = 0.0 , drop_path_rate = 0.0 ,
178
- drop_block_rate = 0.5 ),
180
+ drop_block_rate = 0.0 ),
179
181
block_args:: NamedTuple = NamedTuple ())
180
- # Stem
181
- inplanes, stem = stem_fn (; inchannels, stem_args... )
182
182
# Feature Blocks
183
183
channels = [64 , 128 , 256 , 512 ]
184
184
stage_blocks = _make_blocks (block, channels, layers, inplanes;
185
- output_stride, downsample_fn, downsample_args,
186
- drop_block_rate = drop_rates. drop_block_rate,
187
- drop_path_rate = drop_rates. drop_path_rate,
188
- block_args)
185
+ output_stride, downsample_fn, drop_rates, block_args)
189
186
# Head (Pooling and Classifier)
190
187
expansion = expansion_factor (block)
191
188
num_features = 512 * expansion
0 commit comments