@@ -132,32 +132,13 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
132
132
:C => (downsample_conv, downsample_conv),
133
133
:D => (downsample_pool, downsample_identity))
134
134
135
- # Makes the downsample `Vector`` with `NTuple{2}`s of functions when it is
136
- # specified as a `Vector` of `Symbol`s. This is used to make the downsample
137
- # `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is
138
- # already an `NTuple{2}` of functions, it is returned unchanged.
139
- function _make_downsample_fns (vec:: Vector{<:Symbol} , layers)
140
- downs = []
141
- for i in vec
142
- @assert i in keys (shortcut_dict)
143
- " The shortcut type must be one of $(sort (collect (keys (shortcut_dict)))) "
144
- push! (downs, shortcut_dict[i])
145
- end
146
- return downs
147
- end
148
- function _make_downsample_fns (sym:: Symbol , layers)
149
- @assert sym in keys (shortcut_dict)
150
- " The shortcut type must be one of $(sort (collect (keys (shortcut_dict)))) "
151
- return collect (shortcut_dict[sym] for _ in 1 : length (layers))
152
- end
153
- _make_downsample_fns (vec:: Vector{<:NTuple{2}} , layers) = vec
154
- _make_downsample_fns (tup:: NTuple{2} , layers) = collect (tup for _ in 1 : length (layers))
155
-
156
135
# Stride for each block in the ResNet model.
# Spatial downsampling (stride 2) happens only at the leading block of every
# stage after the first; all other positions keep stride 1, as in the paper.
function get_stride(block_idx::Integer, stage_idx::Integer)
    if stage_idx == 1 || block_idx != 1
        return 1
    else
        return 2
    end
end
158
139
159
140
# returns `DropBlock`s for each stage of the ResNet as in timm.
160
- # TODO - add experimental options for DropBlock as part of the API
141
+ # TODO - add experimental options for DropBlock as part of the API (#188)
161
142
function _drop_blocks (drop_block_rate:: AbstractFloat )
162
143
return [
163
144
identity, identity,
@@ -225,16 +206,24 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3,
225
206
return Chain (conv1, bn1, stempool), inplanes
226
207
end
227
208
228
- function template_builder (:: typeof (basicblock); reduction_factor = 1 , activation = relu,
229
- norm_layer = BatchNorm, prenorm = false ,
209
# Templating builders for the blocks and the downsampling layers.
# Captures the template-time keywords now and returns a two-argument callable;
# call-time `_kwargs` are splatted *after* the captured ones, so per-call
# keywords override the template defaults on conflict.
function template_builder(block_fn; kwargs...)
    return (inplanes, planes; _kwargs...) -> block_fn(inplanes, planes; kwargs...,
                                                      _kwargs...)
end
+
216
# Specialized template for `basicblock`: pins the block-specific options at
# template time. The trailing `kargs...` deliberately swallows keywords meant
# for other builders, so callers can forward one shared kwarg set everywhere.
function template_builder(::typeof(basicblock); reduction_factor::Integer = 1,
                          activation = relu, norm_layer = BatchNorm,
                          prenorm::Bool = false,
                          attn_fn = planes -> identity, kargs...)
    # Captured options come after call-time `kwargs`, so the template's
    # settings win if the same keyword is supplied at both points.
    template(args...; kwargs...) = basicblock(args...; kwargs..., reduction_factor,
                                              activation, norm_layer, prenorm, attn_fn)
    return template
end
234
222
235
- function template_builder (:: typeof (bottleneck); cardinality = 1 , base_width:: Integer = 64 ,
236
- reduction_factor = 1 , activation = relu,
237
- norm_layer = BatchNorm, prenorm = false ,
223
+ function template_builder (:: typeof (bottleneck); cardinality:: Integer = 1 ,
224
+ base_width:: Integer = 64 ,
225
+ reduction_factor:: Integer = 1 , activation = relu,
226
+ norm_layer = BatchNorm, prenorm:: Bool = false ,
238
227
attn_fn = planes -> identity, kargs... )
239
228
return (args... ; kwargs... ) -> bottleneck (args... ; kwargs... , cardinality, base_width,
240
229
reduction_factor, activation,
@@ -248,30 +237,32 @@ function template_builder(downsample_fn::Union{typeof(downsample_conv),
248
237
return (args... ; kwargs... ) -> downsample_fn (args... ; kwargs... , norm_layer, prenorm)
249
238
end
250
239
251
- function configure_block (block_template, layers:: Vector{Int} ; expansion,
252
- downsample_templates:: Vector , inplanes:: Integer = 64 ,
253
- drop_path_rate = 0.0 , drop_block_rate = 0.0 , kargs... )
254
- pathschedule = linear_scheduler (drop_path_rate; depth = sum (layers))
255
- blockschedule = linear_scheduler (drop_block_rate; depth = sum (layers))
240
# Feature-map width for a ResNet stage: 64 planes in stage 1, doubling with
# every subsequent stage (64, 128, 256, 512, ...).
function resnet_planes(stage_idx::Integer)
    return 64 * 2^(stage_idx - 1)
end
241
+
242
# Builds the `get_layers` closure consumed by `resnet_stages`: given a
# (stage, block) position it returns the `(block, downsample)` layer pair for
# that slot, materialized from the supplied templates.
#
# Keywords:
#   - `stride_fn`/`plane_fn`: user-tweakable callbacks for the stride and the
#     per-stage plane count; default to the standard paper behaviour.
#   - `downsample_templates`: `(real_downsample, identity_downsample)` pair.
#   - `drop_path_rate`/`drop_block_rate`: peak rates for the linear schedules.
#   - `kwargs...` absorbs unrelated keywords forwarded from `resnet`.
function configure_resnet_block(block_template, expansion, block_repeats::Vector{<:Integer};
                                stride_fn = get_stride, plane_fn = resnet_planes,
                                downsample_templates::NTuple{2, Any},
                                inplanes::Integer = 64,
                                drop_path_rate = 0.0, drop_block_rate = 0.0, kwargs...)
    # DropPath/DropBlock rates follow a linear schedule over total depth (as in timm)
    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
    # Stateful closure: `inplanes` is reassigned after every block, so callers
    # must invoke `get_layers` in network order (stage by stage, block by block).
    function get_layers(stage_idx::Integer, block_idx::Integer)
        planes = plane_fn(stage_idx)
        # `stride_fn` is a callback that the user can tweak to change the stride
        # of the blocks. It defaults to the standard behaviour as in the paper.
        # Bug fix: the default `get_stride` is declared as
        # `get_stride(block_idx, stage_idx)`, but the arguments were being
        # passed as `(stage_idx, block_idx)`, inverting the stride logic.
        stride = stride_fn(block_idx, stage_idx)
        # Use the real downsample only when the skip path's shape changes;
        # otherwise fall back to the cheap identity-style template.
        downsample_template = (stride != 1 || inplanes != planes * expansion) ?
                              downsample_templates[1] : downsample_templates[2]
        # DropBlock, DropPath both take in rates based on a linear scaling schedule
        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
        drop_path = DropPath(pathschedule[schedule_idx])
        drop_block = DropBlock(blockschedule[schedule_idx])
        block = block_template(inplanes, planes; stride, drop_path, drop_block)
        downsample = downsample_template(inplanes, planes * expansion; stride)
        # inplanes increases by expansion after each block
        inplanes = planes * expansion
        return block, downsample
    end
    return get_layers
end
@@ -280,43 +271,48 @@ end
280
271
# used by end-users. `block_fn` is a function that returns a single block of the ResNet.
281
272
# See `basicblock` and `bottleneck` for examples. A block must define a function
282
273
# `expansion(::typeof(block))` that returns the expansion factor of the block.
283
# Assembles the residual stages of a ResNet: one `Chain` of blocks per stage,
# where each block is a `Parallel` combining the main path and its downsample
# path through `connection`. Returns the stages wrapped in an outer `Chain`.
function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection)
    # NOTE: `get_layers` is stateful (it tracks `inplanes`), so construction
    # order matters — the comprehensions below run stage-by-stage,
    # block-by-block, exactly like the original push! loops.
    # Typed comprehensions replace the untyped `stages = []` accumulator.
    stages = [Chain([Parallel(connection, get_layers(stage_idx, block_idx)...)
                     for block_idx in 1:num_blocks]...)
              for (stage_idx, num_blocks) in enumerate(block_repeats)]
    return Chain(stages...)
end
300
285
301
- function resnet (block_fn, layers:: Vector{Int} , downsample_opt = :B ;
302
- inchannels:: Integer = 3 , nclasses:: Integer = 1000 ,
286
# Low-level `resnet` assembler: wires a prebuilt `stem` and `classifier`
# around the residual stages produced from `get_layers`, `block_repeats`
# and `connection`.
function resnet(connection, get_layers, block_repeats::Vector{<:Integer}, stem, classifier)
    backbone = Chain(stem, resnet_stages(get_layers, block_repeats, connection))
    return Chain(backbone, classifier)
end
290
+
291
# Mid-level `resnet` constructor: builds the full model (stem → stages →
# classifier) from a block function and per-stage repeat counts.
#
# Arguments:
#   - `block_fn`: block constructor (e.g. `basicblock`, `bottleneck`); must
#     have an `expansion_factor` method and a `template_builder` method.
#   - `block_repeats`: number of blocks in each stage.
#   - `downsample_opt`: `(real, identity)` downsample function pair used when
#     the skip path does / does not change shape.
# Keywords include the input image size (`imsize`, used only to infer the
# classifier width), stem/classifier options, and `kwargs...` forwarded to the
# block and stage configuration.
function resnet(block_fn, block_repeats::Vector{<:Integer},
                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
                imsize::Dims{2} = (256, 256), inchannels::Integer = 3,
                stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64,
                connection = addact, activation = relu,
                pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
                dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
    # Configure downsample templates
    downsample_templates = map(template_builder, downsample_opt)
    # Configure block templates
    block_template = template_builder(block_fn; kwargs...)
    get_layers = configure_resnet_block(block_template, expansion_factor(block_fn),
                                        block_repeats; inplanes, downsample_templates,
                                        kwargs...)
    # Build stages of the ResNet
    # NOTE(review): `connection$activation` presumably partially applies
    # `activation` to `connection` (e.g. via PartialFunctions) — confirm where
    # `$` is defined in this package.
    stage_blocks = resnet_stages(get_layers, block_repeats, connection$activation)
    backbone = Chain(stem, stage_blocks)
    # Build the classifier head: infer the backbone's output channel count with
    # a shape-only forward pass instead of threading it through the builders.
    nfeaturemaps = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true)[3]
    classifier = create_classifier(nfeaturemaps, nclasses; dropout_rate, pool_layer,
                                   use_conv)
    return Chain(backbone, classifier)
end
314
# Convenience method: resolve the downsample pair from its `Symbol` shorthand
# (one of the keys of `shortcut_dict`, e.g. `:A`/`:B`/`:C`/`:D`) and delegate
# to the tuple-based method above.
function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
    # Bug fix: keywords must be forwarded after `;`. The previous
    # `..., shortcut_dict[downsample_opt], kwargs...)` splatted each
    # `key => value` pair as a *positional* argument, so any nonempty
    # `kwargs` raised a MethodError instead of reaching the inner method.
    return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)
end
321
317
322
318
# block-layer configurations for ResNet-like models
0 commit comments