Commit 54ea529

Revert "Remove templating for now"
This reverts commit 8c9f73f.
1 parent ca53acb · commit 54ea529

File tree: 2 files changed, +96 -91 lines

src/convnets/resnets/core.jl

Lines changed: 94 additions & 89 deletions
@@ -15,35 +15,32 @@ Creates a basic ResNet block.
   - `downsample`: the downsampling function to use
   - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first
     convolution.
-  - `connection`: the function applied to the output of residual and skip paths in
-    a block. See [`addact`](#) and [`actadd`](#) for an example.
+  - `dilation`: the dilation of the second convolution.
+  - `first_dilation`: the dilation of the first convolution.
   - `activation`: the activation function to use.
+  - `connection`: the function applied to the output of residual and skip paths in
+    a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses
+    PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`.
   - `norm_layer`: the normalization layer to use.
   - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
     function and passed in.
   - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
     function and passed in.
   - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
 """
-function basicblock(inplanes::Integer, planes::Integer; downsample_fns,
-                    stride::Integer = 1, reduction_factor::Integer = 1,
-                    connection = addact, activation = relu,
-                    norm_layer = BatchNorm, prenorm::Bool = false,
+function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
+                    norm_layer = BatchNorm, prenorm = false,
                     drop_block = identity, drop_path = identity,
                     attn_fn = planes -> identity)
-    expansion = expansion_factor(basicblock)
     first_planes = planes ÷ reduction_factor
-    outplanes = planes * expansion
+    outplanes = planes * expansion_factor(basicblock)
     conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, prenorm,
                          stride, pad = 1, bias = false)
     conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, prenorm,
                          pad = 1, bias = false)
     layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes),
               drop_path]
-    downsample = downsample_block(downsample_fns, inplanes, planes, expansion;
-                                  stride, norm_layer, prenorm)
-    return Parallel(connection$activation, Chain(filter!(!=(identity), layers)...),
-                    downsample)
+    return Chain(filter!(!=(identity), layers)...)
 end
 expansion_factor(::typeof(basicblock)) = 1

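Note: the restored docstring refers to `addact`/`actadd` and the PartialFunctions.jl `$` notation. As a rough illustration only (the real `addact`/`actadd` definitions live elsewhere in Metalhead and may differ in detail, and the behaviour of `$` is assumed from PartialFunctions.jl's documented partial application), `$` pre-applies the activation so that the resulting function takes just the residual and skip outputs, which is the shape `Parallel` expects for its `connection`:

using Flux, PartialFunctions   # `relu` from Flux/NNlib, the `$` operator from PartialFunctions.jl

# Hypothetical stand-in for `addact`: sum the residual and skip paths, then activate.
myaddact(activation, xs...) = activation.(.+(xs...))

connection = myaddact$relu            # fixes the activation, leaving the two paths open
connection([1.0, -2.0], [0.5, 0.5])   # == relu.([1.5, -1.5]) == [1.5, 0.0]
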
@@ -66,38 +63,35 @@ Creates a bottleneck ResNet block.
   - `base_width`: the number of output feature maps for each convolutional group.
   - `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first
     convolution.
+  - `first_dilation`: the dilation of the 3x3 convolution.
   - `activation`: the activation function to use.
   - `connection`: the function applied to the output of residual and skip paths in
-    a block. See [`addact`](#) and [`actadd`](#) for an example.
+    a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses
+    PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`.
   - `norm_layer`: the normalization layer to use.
   - `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
     function and passed in.
   - `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
     function and passed in.
   - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
 """
-function bottleneck(inplanes::Integer, planes::Integer; downsample_fns, stride::Integer = 1,
-                    cardinality::Integer = 1, base_width::Integer = 64,
-                    reduction_factor::Integer = 1, connection = addact, activation = relu,
-                    norm_layer = BatchNorm, prenorm::Bool = false,
+function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
+                    reduction_factor = 1, activation = relu,
+                    norm_layer = BatchNorm, prenorm = false,
                     drop_block = identity, drop_path = identity,
                     attn_fn = planes -> identity)
-    expansion = expansion_factor(bottleneck)
     width = floor(Int, planes * (base_width / 64)) * cardinality
     first_planes = width ÷ reduction_factor
-    outplanes = planes * expansion
+    outplanes = planes * expansion_factor(bottleneck)
     conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, prenorm,
                          bias = false)
     conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, prenorm,
                          stride, pad = 1, groups = cardinality, bias = false)
     conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, prenorm,
                          bias = false)
-    downsample = downsample_block(downsample_fns, inplanes, planes, expansion;
-                                  stride, norm_layer, prenorm)
     layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3...,
               attn_fn(outplanes), drop_path]
-    return Parallel(connection$activation, Chain(filter!(!=(identity), layers)...),
-                    downsample)
+    return Chain(filter!(!=(identity), layers)...)
 end
 expansion_factor(::typeof(bottleneck)) = 4

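For concreteness, the width/expansion arithmetic in the bottleneck signature above works out as follows; the ResNet-50 and ResNeXt-50 (32x4d) numbers are standard configurations shown purely as a worked example:

# First bottleneck stage, planes = 64:
# ResNet-50 style: cardinality = 1, base_width = 64
width     = floor(Int, 64 * (64 / 64)) * 1    # 64  -> width of the 3x3 conv
outplanes = 64 * 4                            # 256 -> expansion_factor(bottleneck) == 4

# ResNeXt-50 (32x4d) style: cardinality = 32, base_width = 4
width     = floor(Int, 64 * (4 / 64)) * 32    # 128, split across 32 groups
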
@@ -132,12 +126,6 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
     end
 end
 
-function downsample_block(downsample_fns, inplanes, planes, expansion; stride, kwargs...)
-    down_fn = (stride != 1 || inplanes != planes * expansion) ? downsample_fns[1] :
-              downsample_fns[2]
-    return down_fn(inplanes, planes * expansion; stride, kwargs...)
-end
-
 # Shortcut configurations for the ResNet models
 const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
                            :B => (downsample_conv, downsample_identity),
@@ -148,7 +136,7 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
 # specified as a `Vector` of `Symbol`s. This is used to make the downsample
 # `Vector` for the `_make_blocks` function. If the `eltype(::Vector)` is
 # already an `NTuple{2}` of functions, it is returned unchanged.
-function _make_downsample_fns(vec::Vector{<:Symbol}, block_repeats)
+function _make_downsample_fns(vec::Vector{<:Symbol}, layers)
     downs = []
     for i in vec
         @assert i in keys(shortcut_dict)
@@ -157,21 +145,19 @@ function _make_downsample_fns(vec::Vector{<:Symbol}, block_repeats)
     end
     return downs
 end
-function _make_downsample_fns(sym::Symbol, block_repeats)
+function _make_downsample_fns(sym::Symbol, layers)
     @assert sym in keys(shortcut_dict)
     "The shortcut type must be one of $(sort(collect(keys(shortcut_dict))))"
-    return collect(shortcut_dict[sym] for _ in 1:length(block_repeats))
+    return collect(shortcut_dict[sym] for _ in 1:length(layers))
 end
-_make_downsample_fns(vec::Vector{<:NTuple{2}}, block_repeats) = vec
-_make_downsample_fns(tup::NTuple{2}, block_repeats) = [tup for _ in 1:length(block_repeats)]
+_make_downsample_fns(vec::Vector{<:NTuple{2}}, layers) = vec
+_make_downsample_fns(tup::NTuple{2}, layers) = collect(tup for _ in 1:length(layers))
 
 # Stride for each block in the ResNet model
-function get_stride(stage_idx::Integer, block_idx::Integer)
-    return (stage_idx == 1 || block_idx != 1) ? 1 : 2
-end
+get_stride(idxs::NTuple{2, Int}) = (idxs[1] == 1 || idxs[2] != 1) ? 1 : 2
 
 # returns `DropBlock`s for each stage of the ResNet as in timm.
-# TODO - add experimental options for DropBlock as part of the API (#188)
+# TODO - add experimental options for DropBlock as part of the API
 function _drop_blocks(drop_block_rate::AbstractFloat)
     return [
         identity, identity,
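To make the shortcut and stride bookkeeping concrete, here is how these helpers behave; the values follow directly from the code in this hunk and from `shortcut_dict`, with `[2, 2, 2, 2]` (a ResNet-18 layout) used only as an example:

# One (downsample, no-downsample) tuple per stage:
_make_downsample_fns(:B, [2, 2, 2, 2])
# -> [(downsample_conv, downsample_identity), (downsample_conv, downsample_identity),
#     (downsample_conv, downsample_identity), (downsample_conv, downsample_identity)]

# Stride is 2 only for the first block of every stage after the first:
get_stride((1, 1))   # 1 -- the first stage never downsamples
get_stride((2, 1))   # 2 -- the first block of a later stage halves the resolution
get_stride((2, 2))   # 1 -- the remaining blocks keep the resolution
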
@@ -201,7 +187,8 @@ on how to use this function.
   shows peformance improvements over the `:deep` stem in some cases.
 
   - `inchannels`: The number of channels in the input.
-  - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two.
+  - `replace_pool`: Whether to replace the default 3x3 max pooling layer with a
+    3x3 convolution with stride 2 and a normalisation layer.
   - `norm_layer`: The normalisation layer used in the stem.
   - `activation`: The activation function used in the stem.
 """
@@ -232,86 +219,104 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3,
     # Stem pooling
     stempool = replace_pool ?
                Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer,
-                               prenorm, stride = 2, pad = 1, bias = false)...) :
+                               prenorm,
+                               stride = 2, pad = 1, bias = false)...) :
                MaxPool((3, 3); stride = 2, pad = 1)
     return Chain(conv1, bn1, stempool), inplanes
 end
 
-function block_args(::typeof(basicblock), block_repeats;
-                    downsample_vec, reduction_factor = 1, activation = relu,
-                    norm_layer = BatchNorm, prenorm = false,
-                    drop_path_rate = 0.0, drop_block_rate = 0.0,
-                    attn_fn = planes -> identity)
-    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
-    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
-    function get_layers(stage_idx, block_idx)
-        stride = get_stride(stage_idx, block_idx)
-        downsample_fns = downsample_vec[stage_idx]
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
-        drop_path = DropPath(pathschedule[schedule_idx])
-        drop_block = DropBlock(blockschedule[schedule_idx])
-        return (; downsample_fns, reduction_factor, stride, activation, norm_layer,
-                prenorm, drop_path, drop_block, attn_fn)
-    end
+function template_builder(::typeof(basicblock); reduction_factor = 1, activation = relu,
+                          norm_layer = BatchNorm, prenorm = false,
+                          attn_fn = planes -> identity, kargs...)
+    return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor,
+                                              activation, norm_layer, prenorm, attn_fn)
 end
 
-function block_args(::typeof(bottleneck), block_repeats;
-                    downsample_vec, cardinality = 1, base_width = 64,
-                    reduction_factor = 1, activation = relu,
-                    norm_layer = BatchNorm, prenorm = false,
-                    drop_block_rate = 0.0, drop_path_rate = 0.0,
-                    attn_fn = planes -> identity)
-    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
-    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
-    function get_layers(stage_idx, block_idx)
-        stride = get_stride(stage_idx, block_idx)
-        downsample_fns = downsample_vec[stage_idx]
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+function template_builder(::typeof(bottleneck); cardinality = 1, base_width::Integer = 64,
+                          reduction_factor = 1, activation = relu,
+                          norm_layer = BatchNorm, prenorm = false,
+                          attn_fn = planes -> identity, kargs...)
+    return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width,
+                                              reduction_factor, activation,
+                                              norm_layer, prenorm, attn_fn)
+end
+
+function template_builder(downsample_fn::Union{typeof(downsample_conv),
+                                               typeof(downsample_pool),
+                                               typeof(downsample_identity)};
+                          norm_layer = BatchNorm, prenorm = false)
+    return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm)
+end
+
+function configure_block(block_template, layers::Vector{Int}; expansion,
+                         downsample_templates::Vector, inplanes::Integer = 64,
+                         drop_path_rate = 0.0, drop_block_rate = 0.0, kargs...)
+    pathschedule = linear_scheduler(drop_path_rate; depth = sum(layers))
+    blockschedule = linear_scheduler(drop_block_rate; depth = sum(layers))
+    # closure over `idxs`
+    function get_layers(idxs::NTuple{2, Int})
+        stage_idx, block_idx = idxs
+        planes = 64 * 2^(stage_idx - 1)
+        # `get_stride` is a callback that the user can tweak to change the stride of the
+        # blocks. It defaults to the standard behaviour as in the paper.
+        stride = get_stride(idxs)
+        downsample_fns = downsample_templates[stage_idx]
+        downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
+                        downsample_fns[1] : downsample_fns[2]
+        # DropBlock, DropPath both take in rates based on a linear scaling schedule
+        schedule_idx = sum(layers[1:(stage_idx - 1)]) + block_idx
         drop_path = DropPath(pathschedule[schedule_idx])
         drop_block = DropBlock(blockschedule[schedule_idx])
-        return (; downsample_fns, reduction_factor, cardinality, base_width, stride,
-                activation, norm_layer, prenorm, drop_path, drop_block, attn_fn)
+        block = block_template(inplanes, planes; stride, drop_path, drop_block)
+        downsample = downsample_fn(inplanes, planes * expansion; stride)
+        # inplanes increases by expansion after each block
+        inplanes = (planes * expansion)
+        return ((block, downsample), inplanes)
     end
+    return get_layers
 end
 
 # Makes the main stages of the ResNet model. This is an internal function and should not be
 # used by end-users. `block_fn` is a function that returns a single block of the ResNet.
 # See `basicblock` and `bottleneck` for examples. A block must define a function
 # `expansion(::typeof(block))` that returns the expansion factor of the block.
-function resnet_stages(block_fn, block_repeats::Vector{<:Integer}, inplanes::Integer;
-                       kwargs...)
+function resnet_stages(get_layers, block_repeats::Vector{Int}, inplanes::Integer;
+                       connection = addact, activation = relu, kwargs...)
+    outplanes = 0
     # Construct each stage
     stages = []
     for (stage_idx, (num_blocks)) in enumerate(block_repeats)
-        planes = 64 * 2^(stage_idx - 1)
-        get_kwargs = block_args(block_fn, block_repeats; kwargs...)
         # Construct the blocks for each stage
         blocks = []
         for block_idx in range(1, num_blocks)
-            push!(blocks, block_fn(inplanes, planes; get_kwargs(stage_idx, block_idx)...))
-            inplanes = planes * expansion_factor(block_fn)
+            layers, outplanes = get_layers((stage_idx, block_idx))
+            block = Parallel(connection$activation, layers...)
+            push!(blocks, block)
         end
         push!(stages, Chain(blocks...))
     end
-    return Chain(stages...)
+    return Chain(stages...), outplanes
 end
 
-function resnet(block_fn, block_repeats::Vector{<:Integer}, downsample_opt = :B;
-                imsize::Dims{2} = (256, 256), inchannels::Integer = 3,
+function resnet(block_fn, layers::Vector{Int}, downsample_opt = :B;
+                inchannels::Integer = 3, nclasses::Integer = 1000,
                 stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64,
-                pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = 0.0,
-                use_conv_classifier::Bool = false, nclasses::Integer = 1000, kwargs...)
+                pool_layer = AdaptiveMeanPool((1, 1)), use_conv = false, dropout_rate = 0.0,
+                kwargs...)
     # Configure downsample templates
-    downsample_vec = _make_downsample_fns(downsample_opt, block_repeats)
+    downsample_vec = _make_downsample_fns(downsample_opt, layers)
+    downsample_templates = map(x -> template_builder.(x), downsample_vec)
+    # Configure block templates
+    block_template = template_builder(block_fn; kwargs...)
+    get_layers = configure_block(block_template, layers; inplanes,
+                                 downsample_templates,
+                                 expansion = expansion_factor(block_fn), kwargs...)
     # Build stages of the ResNet
-    stage_blocks = resnet_stages(block_fn, block_repeats, inplanes; downsample_vec,
-                                 kwargs...)
-    backbone = Chain(stem, stage_blocks)
+    stage_blocks, num_features = resnet_stages(get_layers, layers, inplanes; kwargs...)
     # Build the classifier head
-    outfeatures = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true)
-    classifier = create_classifier(outfeatures[3], nclasses; dropout_rate, pool_layer,
-                                   use_conv = use_conv_classifier)
-    return Chain(backbone, classifier)
+    classifier = create_classifier(num_features, nclasses; dropout_rate, pool_layer,
+                                   use_conv)
+    return Chain(Chain(stem, stage_blocks), classifier)
 end
 
 # block-layer configurations for ResNet-like models
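Pieced together from the diff above, the templated pipeline restored by this revert composes roughly as follows; the `[2, 2, 2, 2]` block counts are the usual ResNet-18 layout and are shown only as an example call:

# Illustrative top-level call:
model = resnet(basicblock, [2, 2, 2, 2], :B; inchannels = 3, nclasses = 1000)

# which, per the code above, roughly expands to:
#   downsample_vec       = _make_downsample_fns(:B, [2, 2, 2, 2])
#   downsample_templates = map(x -> template_builder.(x), downsample_vec)
#   block_template       = template_builder(basicblock)      # closes over the block kwargs
#   get_layers           = configure_block(block_template, [2, 2, 2, 2];
#                                          downsample_templates, inplanes = 64,
#                                          expansion = expansion_factor(basicblock))
#   stage_blocks, nf     = resnet_stages(get_layers, [2, 2, 2, 2], 64)
#   model                = Chain(Chain(stem, stage_blocks), create_classifier(nf, 1000))
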

src/layers/drop.jl

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu
 If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead.
 """
 # TODO add experimental `DropBlock` options from timm such as gaussian noise and
-# more precise `DropBlock` to deal with edges (#188)
+# more precise `DropBlock` to deal with edges.
 function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size,
                    gamma_scale) where {T}
     H, W, _, _ = size(x)
@@ -63,7 +63,7 @@ function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_sc
     @assert 0 ≤ drop_block_prob ≤ 1
     "drop_block_prob must be between 0 and 1, got $drop_block_prob"
     @assert 0 ≤ gamma_scale ≤ 1
-    return "gamma_scale must be between 0 and 1, got $gamma_scale"
+    "gamma_scale must be between 0 and 1, got $gamma_scale"
 end
 function _dropblock_checks(x, drop_block_prob, gamma_scale)
     throw(ArgumentError("x must be an array with 4 dimensions (H, W, C, N) for DropBlock."))
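As a side note on the assertion style touched in this hunk (standard Julia behaviour, not specific to this commit): `@assert` attaches a custom failure message only when the string appears on the same line as the condition, for example:

julia> gamma_scale = 1.5;

julia> @assert 0 ≤ gamma_scale ≤ 1 "gamma_scale must be between 0 and 1, got $gamma_scale"
ERROR: AssertionError: gamma_scale must be between 0 and 1, got 1.5
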

0 commit comments
