Skip to content

Commit 674b27e

Browse files
committed
Make templating work again
And expand the lowest level of the ResNet API
1 parent cff07cb commit 674b27e

File tree

2 files changed

+64
-68
lines changed

2 files changed

+64
-68
lines changed

src/convnets/resnets/core.jl

Lines changed: 62 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -132,32 +132,13 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
132132
:C => (downsample_conv, downsample_conv),
133133
:D => (downsample_pool, downsample_identity))
134134

135-
# Makes the downsample `Vector` with `NTuple{2}`s of functions when it is
136-
# specified as a `Vector` of `Symbol`s. This is used to make the downsample
137-
# `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is
138-
# already an `NTuple{2}` of functions, it is returned unchanged.
139-
function _make_downsample_fns(vec::Vector{<:Symbol}, layers)
140-
downs = []
141-
for i in vec
142-
@assert i in keys(shortcut_dict)
143-
"The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))"
144-
push!(downs, shortcut_dict[i])
145-
end
146-
return downs
147-
end
148-
function _make_downsample_fns(sym::Symbol, layers)
149-
@assert sym in keys(shortcut_dict)
150-
"The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))"
151-
return collect(shortcut_dict[sym] for _ in 1:length(layers))
152-
end
153-
_make_downsample_fns(vec::Vector{<:NTuple{2}}, layers) = vec
154-
_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers))
155-
156135
# Stride for each block in the ResNet model
157-
get_stride(idxs::NTuple{2, Int}) = (idxs[1] == 1 || idxs[2] != 1) ? 1 : 2
136+
function get_stride(block_idx::Integer, stage_idx::Integer)
137+
return (stage_idx == 1 || block_idx != 1) ? 1 : 2
138+
end
158139

159140
# returns `DropBlock`s for each stage of the ResNet as in timm.
160-
# TODO - add experimental options for DropBlock as part of the API
141+
# TODO - add experimental options for DropBlock as part of the API (#188)
161142
function _drop_blocks(drop_block_rate::AbstractFloat)
162143
return [
163144
identity, identity,
@@ -225,16 +206,24 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3,
225206
return Chain(conv1, bn1, stempool), inplanes
226207
end
227208

228-
function template_builder(::typeof(basicblock); reduction_factor = 1, activation = relu,
229-
norm_layer = BatchNorm, prenorm = false,
209+
# Templating builders for the blocks and the downsampling layers
210+
function template_builder(block_fn; kwargs...)
211+
function (inplanes, planes; _kwargs...)
212+
return block_fn(inplanes, planes; kwargs..., _kwargs...)
213+
end
214+
end
215+
216+
function template_builder(::typeof(basicblock); reduction_factor::Integer = 1,
217+
activation = relu, norm_layer = BatchNorm, prenorm::Bool = false,
230218
attn_fn = planes -> identity, kargs...)
231219
return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor,
232220
activation, norm_layer, prenorm, attn_fn)
233221
end
234222

235-
function template_builder(::typeof(bottleneck); cardinality = 1, base_width::Integer = 64,
236-
reduction_factor = 1, activation = relu,
237-
norm_layer = BatchNorm, prenorm = false,
223+
function template_builder(::typeof(bottleneck); cardinality::Integer = 1,
224+
base_width::Integer = 64,
225+
reduction_factor::Integer = 1, activation = relu,
226+
norm_layer = BatchNorm, prenorm::Bool = false,
238227
attn_fn = planes -> identity, kargs...)
239228
return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width,
240229
reduction_factor, activation,
@@ -248,30 +237,32 @@ function template_builder(downsample_fn::Union{typeof(downsample_conv),
248237
return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm)
249238
end
250239

251-
function configure_block(block_template, layers::Vector{Int}; expansion,
252-
downsample_templates::Vector, inplanes::Integer = 64,
253-
drop_path_rate = 0.0, drop_block_rate = 0.0, kargs...)
254-
pathschedule = linear_scheduler(drop_path_rate; depth = sum(layers))
255-
blockschedule = linear_scheduler(drop_block_rate; depth = sum(layers))
240+
resnet_planes(stage_idx::Integer) = 64 * 2^(stage_idx - 1)
241+
242+
function configure_resnet_block(block_template, expansion, block_repeats::Vector{<:Integer};
243+
stride_fn = get_stride, plane_fn = resnet_planes,
244+
downsample_templates::NTuple{2, Any},
245+
inplanes::Integer = 64,
246+
drop_path_rate = 0.0, drop_block_rate = 0.0, kwargs...)
247+
pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
248+
blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
256249
# closure over `idxs`
257-
function get_layers(idxs::NTuple{2, Int})
258-
stage_idx, block_idx = idxs
259-
planes = 64 * 2^(stage_idx - 1)
250+
function get_layers(stage_idx::Integer, block_idx::Integer)
251+
planes = plane_fn(stage_idx)
260252
# `get_stride` is a callback that the user can tweak to change the stride of the
261-
# blocks. It defaults to the standard behaviour as in the paper.
262-
stride = get_stride(idxs)
263-
downsample_fns = downsample_templates[stage_idx]
264-
downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
265-
downsample_fns[1] : downsample_fns[2]
253+
# blocks. It defaults to the standard behaviour as in the paper
254+
stride = stride_fn(stage_idx, block_idx)
255+
downsample_template = (stride != 1 || inplanes != planes * expansion) ?
256+
downsample_templates[1] : downsample_templates[2]
266257
# DropBlock, DropPath both take in rates based on a linear scaling schedule
267-
schedule_idx = sum(layers[1:(stage_idx - 1)]) + block_idx
258+
schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
268259
drop_path = DropPath(pathschedule[schedule_idx])
269260
drop_block = DropBlock(blockschedule[schedule_idx])
270261
block = block_template(inplanes, planes; stride, drop_path, drop_block)
271-
downsample = downsample_fn(inplanes, planes * expansion; stride)
262+
downsample = downsample_template(inplanes, planes * expansion; stride)
272263
# inplanes increases by expansion after each block
273264
inplanes = (planes * expansion)
274-
return ((block, downsample), inplanes)
265+
return block, downsample
275266
end
276267
return get_layers
277268
end
@@ -280,43 +271,48 @@ end
280271
# used by end-users. `block_fn` is a function that returns a single block of the ResNet.
281272
# See `basicblock` and `bottleneck` for examples. A block must define a function
282273
# `expansion(::typeof(block))` that returns the expansion factor of the block.
283-
function resnet_stages(get_layers, block_repeats::Vector{Int}, inplanes::Integer;
284-
connection = addact, activation = relu, kwargs...)
285-
outplanes = 0
274+
function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection)
286275
# Construct each stage
287276
stages = []
288277
for (stage_idx, (num_blocks)) in enumerate(block_repeats)
289278
# Construct the blocks for each stage
290-
blocks = []
291-
for block_idx in range(1, num_blocks)
292-
layers, outplanes = get_layers((stage_idx, block_idx))
293-
block = Parallel(connection$activation, layers...)
294-
push!(blocks, block)
295-
end
279+
blocks = [Parallel(connection, get_layers(stage_idx, block_idx)...)
280+
for block_idx in range(1, num_blocks)]
296281
push!(stages, Chain(blocks...))
297282
end
298-
return Chain(stages...), outplanes
283+
return Chain(stages...)
299284
end
300285

301-
function resnet(block_fn, layers::Vector{Int}, downsample_opt = :B;
302-
inchannels::Integer = 3, nclasses::Integer = 1000,
286+
function resnet(connection, get_layers, block_repeats::Vector{<:Integer}, stem, classifier)
287+
stage_blocks = resnet_stages(get_layers, block_repeats, connection)
288+
return Chain(Chain(stem, stage_blocks), classifier)
289+
end
290+
291+
function resnet(block_fn, block_repeats::Vector{<:Integer},
292+
downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
293+
imsize::Dims{2} = (256, 256), inchannels::Integer = 3,
303294
stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64,
304-
pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false, dropout_rate = 0.0,
305-
kwargs...)
295+
connection = addact, activation = relu,
296+
pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
297+
dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
306298
# Configure downsample templates
307-
downsample_vec = _make_downsample_fns(downsample_opt, layers)
308-
downsample_templates = map(x -> template_builder.(x), downsample_vec)
299+
downsample_templates = map(template_builder, downsample_opt)
309300
# Configure block templates
310301
block_template = template_builder(block_fn; kwargs...)
311-
get_layers = configure_block(block_template, layers; inplanes,
312-
downsample_templates,
313-
expansion = expansion_factor(block_fn), kwargs...)
302+
get_layers = configure_resnet_block(block_template, expansion_factor(block_fn),
303+
block_repeats; inplanes, downsample_templates,
304+
kwargs...)
314305
# Build stages of the ResNet
315-
stage_blocks, num_features = resnet_stages(get_layers, layers, inplanes; kwargs...)
306+
stage_blocks = resnet_stages(get_layers, block_repeats, connection$activation)
307+
backbone = Chain(stem, stage_blocks)
316308
# Build the classifier head
317-
classifier = create_classifier(num_features, nclasses; dropout_rate, pool_layer,
309+
nfeaturemaps = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true)[3]
310+
classifier = create_classifier(nfeaturemaps, nclasses; dropout_rate, pool_layer,
318311
use_conv)
319-
return Chain(Chain(stem, stage_blocks), classifier)
312+
return Chain(backbone, classifier)
313+
end
314+
function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
315+
return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt], kwargs...)
320316
end
321317

322318
# block-layer configurations for ResNet-like models

src/layers/drop.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu
2727
If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead.
2828
"""
2929
# TODO add experimental `DropBlock` options from timm such as gaussian noise and
30-
# more precise `DropBlock` to deal with edges.
30+
# more precise `DropBlock` to deal with edges (#188)
3131
function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size,
3232
gamma_scale) where {T}
3333
H, W, _, _ = size(x)
@@ -63,7 +63,7 @@ function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_sc
6363
@assert 0 ≤ drop_block_prob ≤ 1
6464
"drop_block_prob must be between 0 and 1, got $drop_block_prob"
6565
@assert 0 ≤ gamma_scale ≤ 1
66-
"gamma_scale must be between 0 and 1, got $gamma_scale"
66+
return "gamma_scale must be between 0 and 1, got $gamma_scale"
6767
end
6868
function _dropblock_checks(x, drop_block_prob, gamma_scale)
6969
throw(ArgumentError("x must be an array with 4 dimensions (H, W, C, N) for DropBlock."))

0 commit comments

Comments
 (0)