@@ -15,35 +15,32 @@ Creates a basic ResNet block.
   - `downsample`: the downsampling function to use
   - `reduction_factor`: the factor by which the input feature maps are reduced before the first
     convolution.
-  - `connection`: the function applied to the output of residual and skip paths in
-    a block. See [`addact`](#) and [`actadd`](#) for an example.
+  - `dilation`: the dilation of the second convolution.
+  - `first_dilation`: the dilation of the first convolution.
   - `activation`: the activation function to use.
+  - `connection`: the function applied to the output of residual and skip paths in
+    a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses
+    PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`.
   - `norm_layer`: the normalization layer to use.
   - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
    function and passed in.
  - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
    function and passed in.
  - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
"""
-function basicblock(inplanes::Integer, planes::Integer; downsample_fns,
-                    stride::Integer = 1, reduction_factor::Integer = 1,
-                    connection = addact, activation = relu,
-                    norm_layer = BatchNorm, prenorm::Bool = false,
+function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
+                    norm_layer = BatchNorm, prenorm = false,
                     drop_block = identity, drop_path = identity,
                     attn_fn = planes -> identity)
-    expansion = expansion_factor(basicblock)
     first_planes = planes ÷ reduction_factor
-    outplanes = planes * expansion
+    outplanes = planes * expansion_factor(basicblock)
     conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, prenorm,
                          stride, pad = 1, bias = false)
     conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, prenorm,
                          pad = 1, bias = false)
     layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes),
               drop_path]
-    downsample = downsample_block(downsample_fns, inplanes, planes, expansion;
-                                  stride, norm_layer, prenorm)
-    return Parallel(connection$activation, Chain(filter!(!=(identity), layers)...),
-                    downsample)
+    return Chain(filter!(!=(identity), layers)...)
 end
 expansion_factor(::typeof(basicblock)) = 1

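Note on the `addact\$activation` notation introduced in the docstrings above: PartialFunctions.jl exports `$` as a partial-application operator, so `addact$activation` fixes the activation as the first argument of the connection function. A minimal sketch, assuming `addact`/`actadd` are defined roughly as below (the real Metalhead definitions may differ in detail):

    using PartialFunctions        # provides the `$` partial-application operator
    using Flux: relu

    addact(activation, x, y) = activation.(x .+ y)   # add the two paths, then activate
    actadd(activation, x, y) = activation.(x) .+ y   # activate the residual path, then add

    connection = addact$relu     # behaves like (x, y) -> relu.(x .+ y)
    connection(1.0f0, -2.0f0)    # relu(-1.0f0) == 0.0f0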
@@ -66,38 +63,35 @@ Creates a bottleneck ResNet block.
   - `base_width`: the number of output feature maps for each convolutional group.
   - `reduction_factor`: the factor by which the input feature maps are reduced before the first
     convolution.
+  - `first_dilation`: the dilation of the 3x3 convolution.
   - `activation`: the activation function to use.
   - `connection`: the function applied to the output of residual and skip paths in
-    a block. See [`addact`](#) and [`actadd`](#) for an example.
+    a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses
+    PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`.
   - `norm_layer`: the normalization layer to use.
   - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
     function and passed in.
   - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
     function and passed in.
   - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
 """
-function bottleneck(inplanes::Integer, planes::Integer; downsample_fns, stride::Integer = 1,
-                    cardinality::Integer = 1, base_width::Integer = 64,
-                    reduction_factor::Integer = 1, connection = addact, activation = relu,
-                    norm_layer = BatchNorm, prenorm::Bool = false,
+function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
+                    reduction_factor = 1, activation = relu,
+                    norm_layer = BatchNorm, prenorm = false,
                     drop_block = identity, drop_path = identity,
                     attn_fn = planes -> identity)
-    expansion = expansion_factor(bottleneck)
     width = floor(Int, planes * (base_width / 64)) * cardinality
     first_planes = width ÷ reduction_factor
-    outplanes = planes * expansion
+    outplanes = planes * expansion_factor(bottleneck)
     conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, prenorm,
                          bias = false)
     conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, prenorm,
                          stride, pad = 1, groups = cardinality, bias = false)
     conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, prenorm,
                          bias = false)
-    downsample = downsample_block(downsample_fns, inplanes, planes, expansion;
-                                  stride, norm_layer, prenorm)
     layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3...,
               attn_fn(outplanes), drop_path]
-    return Parallel(connection$activation, Chain(filter!(!=(identity), layers)...),
-                    downsample)
+    return Chain(filter!(!=(identity), layers)...)
 end
 expansion_factor(::typeof(bottleneck)) = 4

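In both blocks, `conv_norm` returns a list of layers (a convolution followed by a normalisation layer) rather than a single layer, which is why each `conv_bnN` is splatted into `layers`. A hedged sketch of its shape — `conv_norm_sketch` is a hypothetical simplification, not Metalhead's actual helper, which also handles `prenorm`, activation placement, and more:

    using Flux

    # Hypothetical simplification of Metalhead's conv_norm helper.
    function conv_norm_sketch(kernel, chs::Pair, activation = identity;
                              norm_layer = BatchNorm, stride = 1, pad = 0,
                              groups = 1, bias = false, kwargs...)
        return [Conv(kernel, chs; stride, pad, groups, bias),
                norm_layer(last(chs), activation)]
    end

    layers = [conv_norm_sketch((3, 3), 64 => 64)..., relu]   # splatted, as above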
@@ -132,12 +126,6 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
     end
 end

-function downsample_block(downsample_fns, inplanes, planes, expansion; stride, kwargs...)
-    down_fn = (stride != 1 || inplanes != planes * expansion) ? downsample_fns[1] :
-              downsample_fns[2]
-    return down_fn(inplanes, planes * expansion; stride, kwargs...)
-end
-
 # Shortcut configurations for the ResNet models
 const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
                            :B => (downsample_conv, downsample_identity),
@@ -148,7 +136,7 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
 # specified as a `Vector` of `Symbol`s. This is used to make the downsample
 # `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is
 # already an `NTuple{2}` of functions, it is returned unchanged.
-function _make_downsample_fns(vec::Vector{<:Symbol}, block_repeats)
+function _make_downsample_fns(vec::Vector{<:Symbol}, layers)
     downs = []
     for i in vec
         @assert i in keys(shortcut_dict)
@@ -157,21 +145,19 @@ function _make_downsample_fns(vec::Vector{<:Symbol}, block_repeats)
     end
     return downs
 end
-function _make_downsample_fns(sym::Symbol, block_repeats)
+function _make_downsample_fns(sym::Symbol, layers)
     @assert sym in keys(shortcut_dict)
     "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))"
-    return collect(shortcut_dict[sym] for _ in 1:length(block_repeats))
+    return collect(shortcut_dict[sym] for _ in 1:length(layers))
 end
-_make_downsample_fns(vec::Vector{<:NTuple{2}}, block_repeats) = vec
-_make_downsample_fns(tup::NTuple{2}, block_repeats) = [tup for _ in 1:length(block_repeats)]
+_make_downsample_fns(vec::Vector{<:NTuple{2}}, layers) = vec
+_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers))

 # Stride for each block in the ResNet model
-function get_stride(stage_idx::Integer, block_idx::Integer)
-    return (stage_idx == 1 || block_idx != 1) ? 1 : 2
-end
+get_stride(idxs::NTuple{2, Int}) = (idxs[1] == 1 || idxs[2] != 1) ? 1 : 2

 # returns `DropBlock`s for each stage of the ResNet as in timm.
-# TODO - add experimental options for DropBlock as part of the API (#188)
+# TODO - add experimental options for DropBlock as part of the API
 function _drop_blocks(drop_block_rate::AbstractFloat)
     return [
         identity, identity,
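For concreteness, here is what these helpers produce for a typical configuration (expected values shown as comments; assumes the definitions above):

    downs = _make_downsample_fns(:B, [2, 2, 2, 2])
    length(downs)       # 4 — one NTuple{2} of downsample functions per stage
    downs[1]            # (downsample_conv, downsample_identity)

    get_stride((1, 1))  # 1 — the first stage keeps the input resolution
    get_stride((2, 1))  # 2 — the first block of each later stage downsamples
    get_stride((2, 2))  # 1 — the remaining blocks of a stage do not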
@@ -201,7 +187,8 @@ on how to use this function.
 shows performance improvements over the `:deep` stem in some cases.

   - `inchannels`: The number of channels in the input.
-  - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two.
+  - `replace_pool`: Whether to replace the default 3x3 max pooling layer with a
+    3x3 convolution with stride 2 and a normalisation layer.
   - `norm_layer`: The normalisation layer used in the stem.
   - `activation`: The activation function used in the stem.
 """
@@ -232,86 +219,104 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3,
     # Stem pooling
     stempool = replace_pool ?
                Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer,
-                               prenorm, stride = 2, pad = 1, bias = false)...) :
+                               prenorm,
+                               stride = 2, pad = 1, bias = false)...) :
                MaxPool((3, 3); stride = 2, pad = 1)
     return Chain(conv1, bn1, stempool), inplanes
 end

-function block_args(::typeof(basicblock), block_repeats;
-                    downsample_vec, reduction_factor = 1, activation = relu,
-                    norm_layer = BatchNorm, prenorm = false,
-                    drop_path_rate = 0.0, drop_block_rate = 0.0,
-                    attn_fn = planes -> identity)
-    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
-    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
-    function get_layers(stage_idx, block_idx)
-        stride = get_stride(stage_idx, block_idx)
-        downsample_fns = downsample_vec[stage_idx]
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
-        drop_path = DropPath(pathschedule[schedule_idx])
-        drop_block = DropBlock(blockschedule[schedule_idx])
-        return (; downsample_fns, reduction_factor, stride, activation, norm_layer,
-                prenorm, drop_path, drop_block, attn_fn)
-    end
+function template_builder(::typeof(basicblock); reduction_factor = 1, activation = relu,
+                          norm_layer = BatchNorm, prenorm = false,
+                          attn_fn = planes -> identity, kargs...)
+    return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor,
+                                              activation, norm_layer, prenorm, attn_fn)
 end

-function block_args(::typeof(bottleneck), block_repeats;
-                    downsample_vec, cardinality = 1, base_width = 64,
-                    reduction_factor = 1, activation = relu,
-                    norm_layer = BatchNorm, prenorm = false,
-                    drop_block_rate = 0.0, drop_path_rate = 0.0,
-                    attn_fn = planes -> identity)
-    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
-    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
-    function get_layers(stage_idx, block_idx)
-        stride = get_stride(stage_idx, block_idx)
-        downsample_fns = downsample_vec[stage_idx]
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+function template_builder(::typeof(bottleneck); cardinality = 1, base_width::Integer = 64,
+                          reduction_factor = 1, activation = relu,
+                          norm_layer = BatchNorm, prenorm = false,
+                          attn_fn = planes -> identity, kargs...)
+    return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width,
+                                              reduction_factor, activation,
+                                              norm_layer, prenorm, attn_fn)
+end
+
+function template_builder(downsample_fn::Union{typeof(downsample_conv),
+                                               typeof(downsample_pool),
+                                               typeof(downsample_identity)};
+                          norm_layer = BatchNorm, prenorm = false)
+    return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm)
+end
+
+function configure_block(block_template, layers::Vector{Int}; expansion,
+                         downsample_templates::Vector, inplanes::Integer = 64,
+                         drop_path_rate = 0.0, drop_block_rate = 0.0, kargs...)
+    pathschedule = linear_scheduler(drop_path_rate; depth = sum(layers))
+    blockschedule = linear_scheduler(drop_block_rate; depth = sum(layers))
+    # closure over `idxs`
+    function get_layers(idxs::NTuple{2, Int})
+        stage_idx, block_idx = idxs
+        planes = 64 * 2^(stage_idx - 1)
+        # `get_stride` is a callback that the user can tweak to change the stride of the
+        # blocks. It defaults to the standard behaviour as in the paper.
+        stride = get_stride(idxs)
+        downsample_fns = downsample_templates[stage_idx]
+        downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
+                        downsample_fns[1] : downsample_fns[2]
+        # DropBlock, DropPath both take in rates based on a linear scaling schedule
+        schedule_idx = sum(layers[1:(stage_idx - 1)]) + block_idx
         drop_path = DropPath(pathschedule[schedule_idx])
         drop_block = DropBlock(blockschedule[schedule_idx])
-        return (; downsample_fns, reduction_factor, cardinality, base_width, stride,
-                activation, norm_layer, prenorm, drop_path, drop_block, attn_fn)
+        block = block_template(inplanes, planes; stride, drop_path, drop_block)
+        downsample = downsample_fn(inplanes, planes * expansion; stride)
+        # inplanes increases by expansion after each block
+        inplanes = (planes * expansion)
+        return ((block, downsample), inplanes)
     end
+    return get_layers
 end
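To make the wiring concrete, a hedged sketch of how these templates compose (assuming the definitions in this diff):

    block_template = template_builder(basicblock)   # block constructor with fixed kwargs
    downsample_templates = map(t -> template_builder.(t), _make_downsample_fns(:B, [2, 2]))
    get_layers = configure_block(block_template, [2, 2];
                                 expansion = expansion_factor(basicblock),
                                 downsample_templates, inplanes = 64)
    (block, downsample), outplanes = get_layers((1, 1))   # layers for stage 1, block 1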

 # Makes the main stages of the ResNet model. This is an internal function and should not be
 # used by end-users. `block_fn` is a function that returns a single block of the ResNet.
 # See `basicblock` and `bottleneck` for examples. A block must define a function
 # `expansion(::typeof(block))` that returns the expansion factor of the block.
-function resnet_stages(block_fn, block_repeats::Vector{<:Integer}, inplanes::Integer;
-                       kwargs...)
+function resnet_stages(get_layers, block_repeats::Vector{Int}, inplanes::Integer;
+                       connection = addact, activation = relu, kwargs...)
+    outplanes = 0
     # Construct each stage
     stages = []
     for (stage_idx, (num_blocks)) in enumerate(block_repeats)
-        planes = 64 * 2^(stage_idx - 1)
-        get_kwargs = block_args(block_fn, block_repeats; kwargs...)
         # Construct the blocks for each stage
         blocks = []
         for block_idx in range(1, num_blocks)
-            push!(blocks, block_fn(inplanes, planes; get_kwargs(stage_idx, block_idx)...))
-            inplanes = planes * expansion_factor(block_fn)
+            layers, outplanes = get_layers((stage_idx, block_idx))
+            block = Parallel(connection$activation, layers...)
+            push!(blocks, block)
         end
         push!(stages, Chain(blocks...))
     end
-    return Chain(stages...)
+    return Chain(stages...), outplanes
 end

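Each assembled block is thus `Parallel(connection$activation, block, downsample)`: Flux's `Parallel` applies every wrapped layer to the input and combines their outputs with the connection, which is exactly a residual connection. A minimal sketch, reusing the hypothetical `addact` from earlier:

    using Flux, PartialFunctions

    layer = Parallel(addact$relu, Dense(4 => 4), identity)
    x = rand(Float32, 4, 1)
    layer(x)   # relu.(Dense(4 => 4)(x) .+ x) — main path plus skip path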
-function resnet(block_fn, block_repeats::Vector{<:Integer}, downsample_opt = :B;
-                imsize::Dims{2} = (256, 256), inchannels::Integer = 3,
+function resnet(block_fn, layers::Vector{Int}, downsample_opt = :B;
+                inchannels::Integer = 3, nclasses::Integer = 1000,
                 stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64,
-                pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = 0.0,
-                use_conv_classifier::Bool = false, nclasses::Integer = 1000, kwargs...)
+                pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false, dropout_rate = 0.0,
+                kwargs...)
     # Configure downsample templates
-    downsample_vec = _make_downsample_fns(downsample_opt, block_repeats)
+    downsample_vec = _make_downsample_fns(downsample_opt, layers)
+    downsample_templates = map(x -> template_builder.(x), downsample_vec)
+    # Configure block templates
+    block_template = template_builder(block_fn; kwargs...)
+    get_layers = configure_block(block_template, layers; inplanes,
+                                 downsample_templates,
+                                 expansion = expansion_factor(block_fn), kwargs...)
     # Build stages of the ResNet
-    stage_blocks = resnet_stages(block_fn, block_repeats, inplanes; downsample_vec,
-                                 kwargs...)
-    backbone = Chain(stem, stage_blocks)
+    stage_blocks, num_features = resnet_stages(get_layers, layers, inplanes; kwargs...)
     # Build the classifier head
-    outfeatures = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true)
-    classifier = create_classifier(outfeatures[3], nclasses; dropout_rate, pool_layer,
-                                   use_conv = use_conv_classifier)
-    return Chain(backbone, classifier)
+    classifier = create_classifier(num_features, nclasses; dropout_rate, pool_layer,
+                                   use_conv)
+    return Chain(Chain(stem, stage_blocks), classifier)
 end

 # block-layer configurations for ResNet-like models
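Putting it together, a hedged end-to-end sketch of calling this internal constructor (a ResNet-18-style configuration, assuming the definitions in this diff):

    model = resnet(basicblock, [2, 2, 2, 2], :B; inchannels = 3, nclasses = 1000)
    model(rand(Float32, 224, 224, 3, 1))   # 1000×1 logits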