Skip to content

Commit c998951

Browse files
committed
The real hero was block_idx all along
1 parent ea13dd5 commit c998951

File tree

8 files changed

+176
-188
lines changed

8 files changed

+176
-188
lines changed

src/convnets/efficientnets/core.jl

Lines changed: 42 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,61 @@
1-
abstract type _MBConfig end
2-
3-
struct MBConvConfig <: _MBConfig
4-
kernel_size::Dims{2}
5-
inplanes::Integer
6-
outplanes::Integer
7-
expansion::Real
8-
stride::Integer
9-
nrepeats::Integer
10-
end
11-
function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer,
12-
expansion::Real, stride::Integer, nrepeats::Integer,
13-
width_mult::Real = 1, depth_mult::Real = 1)
"""
    mbconv_builder(block_configs, stage_idx; scalings = (1, 1), norm_layer = BatchNorm)

Build the block-constructor closure for one MBConv stage of an EfficientNet.

`block_configs[stage_idx]` is a `(k, inplanes, outplanes, expansion, stride, nrepeats)`
tuple; `scalings` is `(depth_mult, width_mult)`. Returns `(get_layers, nrepeats)` where
`get_layers(block_idx)` yields the branches for block `block_idx` of the stage —
`(identity, block)` when a residual connection applies (stride 1 and matching
channels), `(block,)` otherwise — and `nrepeats` is the depth-scaled block count.
"""
function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}},
                        stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1),
                        norm_layer = BatchNorm)
    depth_mult, width_mult = scalings
    k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx]
    inplanes = _round_channels(inplanes * width_mult, 8)
    outplanes = _round_channels(outplanes * width_mult, 8)
    function get_layers(block_idx)
        # Use fresh locals here: reassigning the captured `inplanes`/`stride`
        # would box them and mutate state shared across calls, so any call with
        # `block_idx > 1` would corrupt a later `block_idx == 1` call.
        block_inplanes = block_idx == 1 ? inplanes : outplanes
        explanes = _round_channels(block_inplanes * expansion, 8)
        block_stride = block_idx == 1 ? stride : 1
        block = mbconv((k, k), block_inplanes, explanes, outplanes, swish; norm_layer,
                       stride = block_stride, reduction = 4)
        # Residual (identity) branch only when the block preserves resolution and width
        return block_stride == 1 && block_inplanes == outplanes ? (identity, block) :
               (block,)
    end
    return get_layers, ceil(Int, nrepeats * depth_mult)
end
2018

21-
function efficientnetblock(m::MBConvConfig, norm_layer)
22-
layers = []
23-
explanes = _round_channels(m.inplanes * m.expansion, 8)
24-
push!(layers,
25-
mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; norm_layer,
26-
stride = m.stride, reduction = 4))
27-
explanes = _round_channels(m.outplanes * m.expansion, 8)
28-
append!(layers,
29-
[mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; norm_layer,
30-
stride = 1, reduction = 4) for _ in 1:(m.nrepeats - 1)])
31-
return Chain(layers...)
"""
    fused_mbconv_builder(block_configs, stage_idx; scalings = (1, 1),
                         norm_layer = BatchNorm)

Build the block-constructor closure for one fused-MBConv stage of an EfficientNet.

Mirrors [`mbconv_builder`](@ref) but constructs `fused_mbconv` blocks. Accepts the
same `scalings = (depth_mult, width_mult)` keyword (identity by default) so it can be
called interchangeably with `mbconv_builder` from `efficientnet_builder`, which
forwards `scalings` to every residual function. Returns `(get_layers, nrepeats)`.
"""
function fused_mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}},
                              stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1),
                              norm_layer = BatchNorm)
    depth_mult, width_mult = scalings
    k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx]
    inplanes = _round_channels(inplanes * width_mult, 8)
    outplanes = _round_channels(outplanes * width_mult, 8)
    function get_layers(block_idx)
        # Fresh locals (not reassignment of the captured variables) — see
        # `mbconv_builder`: mutating captured state would leak across calls.
        block_inplanes = block_idx == 1 ? inplanes : outplanes
        explanes = _round_channels(block_inplanes * expansion, 8)
        block_stride = block_idx == 1 ? stride : 1
        block = fused_mbconv((k, k), block_inplanes, explanes, outplanes, swish;
                             norm_layer, stride = block_stride)
        return block_stride == 1 && block_inplanes == outplanes ? (identity, block) :
               (block,)
    end
    return get_layers, ceil(Int, nrepeats * depth_mult)
end
3332

34-
struct FusedMBConvConfig <: _MBConfig
35-
kernel_size::Dims{2}
36-
inplanes::Integer
37-
outplanes::Integer
38-
expansion::Real
39-
stride::Integer
40-
nrepeats::Integer
41-
end
42-
function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer,
43-
expansion::Real, stride::Integer, nrepeats::Integer)
44-
return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion,
45-
stride, nrepeats)
46-
end
47-
48-
function efficientnetblock(m::FusedMBConvConfig, norm_layer)
49-
layers = []
50-
explanes = _round_channels(m.inplanes * m.expansion, 8)
51-
push!(layers,
52-
fused_mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish;
53-
norm_layer, stride = m.stride))
54-
explanes = _round_channels(m.outplanes * m.expansion, 8)
55-
append!(layers,
56-
[fused_mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish;
57-
norm_layer, stride = 1) for _ in 1:(m.nrepeats - 1)])
58-
return Chain(layers...)
"""
    efficientnet_builder(block_configs, residual_fns; scalings = (1, 1),
                         norm_layer = BatchNorm)

Assemble the per-stage builders for an EfficientNet backbone.

`residual_fns[stage_idx]` (e.g. `mbconv_builder` or `fused_mbconv_builder`) is called
once per stage with `(block_configs, stage_idx; scalings, norm_layer)` and must return
`(get_layers, nrepeats)`. Returns a `(stage_idx, block_idx) -> branches` closure plus
the vector of per-stage repeat counts, in the shape `resnet_stages` expects.
"""
function efficientnet_builder(block_configs::AbstractVector{NTuple{6, Int}},
                              residual_fns::AbstractVector;
                              scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm)
    bxs = [residual_fn(block_configs, stage_idx; scalings, norm_layer)
           for (stage_idx, residual_fn) in enumerate(residual_fns)]
    # Hoist the extraction out of the returned closure: the original broadcast
    # `first.(bxs)` re-allocated a fresh vector on every single block lookup.
    stage_builders = first.(bxs)
    return (stage_idx, block_idx) -> stage_builders[stage_idx](block_idx), last.(bxs)
end
6040

61-
function efficientnet(block_configs::AbstractVector{<:_MBConfig};
62-
headplanes::Union{Nothing, Integer} = nothing,
"""
    efficientnet(block_configs, residual_fns; scalings = (1, 1), headplanes,
                 norm_layer = BatchNorm, dropout_rate = nothing,
                 inchannels = 3, nclasses = 1000)

Create an EfficientNet model from per-stage block configurations.

`block_configs` holds `(k, inplanes, outplanes, expansion, stride, nrepeats)` tuples;
`residual_fns` supplies one stage-builder function per config (see
`efficientnet_builder`). `headplanes` defaults to four times the width-scaled final
stage output channels. Returns a `Chain(backbone, classifier)`.
"""
function efficientnet(block_configs::AbstractVector{NTuple{6, Int}},
                      residual_fns::AbstractVector; scalings::NTuple{2, Real} = (1, 1),
                      headplanes::Integer = _round_channels(block_configs[end][3] *
                                                            scalings[2], 8) * 4,
                      norm_layer = BatchNorm, dropout_rate = nothing,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
    layers = []
    # stem of the model: 3x3 stride-2 conv into the first stage's input channels
    append!(layers,
            conv_norm((3, 3), inchannels, block_configs[1][2], swish; norm_layer,
                      stride = 2, pad = SamePad()))
    # building inverted residual blocks
    get_layers, block_repeats = efficientnet_builder(block_configs, residual_fns;
                                                     scalings, norm_layer)
    append!(layers, resnet_stages(get_layers, block_repeats, +))
    # building last layers — hoist the width-scaled final feature count, which the
    # original computed twice (once in the `headplanes` default, once here)
    outplanes = _round_channels(block_configs[end][3] * scalings[2], 8)
    append!(layers, conv_norm((1, 1), outplanes, headplanes, swish; pad = SamePad()))
    return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate))
end

src/convnets/efficientnets/efficientnet.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ const EFFICIENTNET_BLOCK_CONFIGS = [
99
(5, 112, 192, 6, 2, 4),
1010
(3, 192, 320, 6, 1, 1),
1111
]
12-
1312
# Data is organised as (r, (w, d))
1413
# r: image resolution
1514
# w: width scaling
@@ -44,9 +43,10 @@ end
4443
function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
4544
nclasses::Integer = 1000)
4645
_checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
47-
cfg_fn = (args...) -> MBConvConfig(args..., EFFICIENTNET_GLOBAL_CONFIGS[config][2]...)
48-
block_configs = [cfg_fn(args...) for args in EFFICIENTNET_BLOCK_CONFIGS]
49-
layers = efficientnet(block_configs; inchannels, nclasses)
46+
scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
47+
layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS,
48+
fill(mbconv_builder, length(EFFICIENTNET_BLOCK_CONFIGS));
49+
scalings, inchannels, nclasses)
5050
if pretrain
5151
loadpretrain!(layers, string("efficientnet-", config))
5252
end

src/convnets/efficientnets/efficientnetv2.jl

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,36 @@
11
# block configs for EfficientNetv2
# Stage configurations for the EfficientNetv2 variants.
# Each entry is a (k, i, o, e, s, n) tuple:
#   k = kernel size, i = input channels, o = output channels,
#   e = expansion ratio, s = stride, n = number of block repeats.
# NOTE(review): in :xlarge the jump from output 224 to input 384 between stages 5
# and 6 breaks the usual i == previous o pattern — verify against the reference
# implementation before relying on it.
const EFFNETV2_CONFIGS = Dict(:small => [
                                  (3, 24, 24, 1, 1, 2),
                                  (3, 24, 48, 4, 2, 4),
                                  (3, 48, 64, 4, 2, 4),
                                  (3, 64, 128, 4, 2, 6),
                                  (3, 128, 160, 6, 1, 9),
                                  (3, 160, 256, 6, 2, 15)],
                              :medium => [
                                  (3, 24, 24, 1, 1, 3),
                                  (3, 24, 48, 4, 2, 5),
                                  (3, 48, 80, 4, 2, 5),
                                  (3, 80, 160, 4, 2, 7),
                                  (3, 160, 176, 6, 1, 14),
                                  (3, 176, 304, 6, 2, 18),
                                  (3, 304, 512, 6, 1, 5)],
                              :large => [
                                  (3, 32, 32, 1, 1, 4),
                                  (3, 32, 64, 4, 2, 7),
                                  (3, 64, 96, 4, 2, 7),
                                  (3, 96, 192, 4, 2, 10),
                                  (3, 192, 224, 6, 1, 19),
                                  (3, 224, 384, 6, 2, 25),
                                  (3, 384, 640, 6, 1, 7)],
                              :xlarge => [
                                  (3, 32, 32, 1, 1, 4),
                                  (3, 32, 64, 4, 2, 8),
                                  (3, 64, 96, 4, 2, 8),
                                  (3, 96, 192, 4, 2, 16),
                                  (3, 192, 224, 6, 1, 24),
                                  (3, 384, 512, 6, 2, 32),
                                  (3, 512, 768, 6, 1, 8)])
3334

3435
"""
3536
EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1,
@@ -54,8 +55,9 @@ end
5455
function EfficientNetv2(config::Symbol; pretrain::Bool = false,
5556
inchannels::Integer = 3, nclasses::Integer = 1000)
5657
_checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS))))
57-
layers = efficientnet(EFFNETV2_CONFIGS[config]; headplanes = 1280, inchannels,
58-
nclasses)
58+
layers = efficientnet(EFFNETV2_CONFIGS[config],
59+
vcat(fill(fused_mbconv_builder, 3), fill(mbconv_builder, 4));
60+
headplanes = 1280, inchannels, nclasses)
5961
if pretrain
6062
loadpretrain!(layers, string("efficientnetv2"))
6163
end

src/convnets/resnets/core.jl

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer};
203203
drop_block_rate = nothing, drop_path_rate = nothing,
204204
stride_fn = resnet_stride, planes_fn = resnet_planes,
205205
downsample_tuple = (downsample_conv, downsample_identity))
206+
# DropBlock, DropPath both take in rates based on a linear scaling schedule
207+
# Also get `planes_vec` needed for block `inplanes` and `planes` calculations
206208
pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
207209
blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
208210
planes_vec = collect(planes_fn(block_repeats))
@@ -265,22 +267,26 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
265267
return get_layers
266268
end
267269

# TODO @theabhirath figure out a better name and potentially refactor other CNNs to use this
"""
    resnet_stages(get_layers, block_repeats, connection)

Chain together the stages of a ResNet-style backbone.

For stage `sidx` with `block_repeats[sidx]` blocks, `get_layers(sidx, bidx)` returns
the branches of block `bidx`: a single branch is used directly, while multiple
branches are merged with `Parallel(connection, ...)`. Returns a `Chain` of per-stage
`Chain`s.
"""
function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection)
    stage_chains = []
    for (sidx, nblocks) in enumerate(block_repeats)
        # Build every block of this stage; skip the Parallel wrapper when a block
        # has only one branch (i.e. no residual connection).
        blocks_in_stage = [begin
                               branches = get_layers(sidx, bidx)
                               length(branches) == 1 ? only(branches) :
                               Parallel(connection, branches...)
                           end
                           for bidx in 1:nblocks]
        push!(stage_chains, Chain(blocks_in_stage...))
    end
    return Chain(stage_chains...)
end
279285

280-
function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
286+
function resnet(img_dims, stem, builders, block_repeats::AbstractVector{<:Integer},
281287
connection, classifier_fn)
282288
# Build stages of the ResNet
283-
stage_blocks = resnet_stages(get_layers, block_repeats, connection)
289+
stage_blocks = resnet_stages(builders, block_repeats, connection)
284290
backbone = Chain(stem, stage_blocks)
285291
# Add classifier to the backbone
286292
nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
@@ -302,39 +308,37 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
302308
if block_type == basicblock
303309
@assert cardinality==1 "Cardinality must be 1 for `basicblock`"
304310
@assert base_width==64 "Base width must be 64 for `basicblock`"
305-
get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
306-
activation, norm_layer, revnorm, attn_fn,
307-
drop_block_rate, drop_path_rate,
308-
stride_fn = resnet_stride,
309-
planes_fn = resnet_planes,
310-
downsample_tuple = downsample_opt,
311-
kwargs...)
311+
builder = basicblock_builder(block_repeats; inplanes, reduction_factor,
312+
activation, norm_layer, revnorm, attn_fn,
313+
drop_block_rate, drop_path_rate,
314+
stride_fn = resnet_stride, planes_fn = resnet_planes,
315+
downsample_tuple = downsample_opt, kwargs...)
312316
elseif block_type == bottleneck
313-
get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
314-
reduction_factor, activation, norm_layer, revnorm,
315-
attn_fn, drop_block_rate, drop_path_rate,
316-
stride_fn = resnet_stride,
317-
planes_fn = resnet_planes,
318-
downsample_tuple = downsample_opt,
319-
kwargs...)
317+
builder = bottleneck_builder(block_repeats; inplanes, cardinality,
318+
base_width, reduction_factor, activation, norm_layer,
319+
revnorm, attn_fn, drop_block_rate, drop_path_rate,
320+
stride_fn = resnet_stride, planes_fn = resnet_planes,
321+
downsample_tuple = downsample_opt, kwargs...)
320322
elseif block_type == bottle2neck
321-
@assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing"
322-
@assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing"
323-
@assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1"
323+
@assert isnothing(drop_block_rate)
324+
"DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing"
325+
@assert isnothing(drop_path_rate)
326+
"DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing"
327+
@assert reduction_factor == 1
328+
"Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1"
324329
get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width,
325330
activation, norm_layer, revnorm, attn_fn,
326331
stride_fn = resnet_stride,
327332
planes_fn = resnet_planes,
328-
downsample_tuple = downsample_opt,
329-
kwargs...)
333+
downsample_tuple = downsample_opt, kwargs...)
330334
else
331335
# TODO: write better message when we have link to dev docs for resnet
332336
throw(ArgumentError("Unknown block type $block_type"))
333337
end
334338
classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate,
335339
pool_layer, use_conv)
336-
return resnet((imsize..., inchannels), stem, get_layers, block_repeats,
337-
connection$activation, classifier_fn)
340+
return resnet((imsize..., inchannels), stem, fill(builder, length(block_repeats)),
341+
block_repeats, connection$activation, classifier_fn)
338342
end
339343
function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
340344
return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...)

src/convnets/resnets/res2net.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ function bottle2neck_builder(block_repeats::AbstractVector{<:Integer};
5454
attn_fn = planes -> identity,
5555
stride_fn = resnet_stride, planes_fn = resnet_planes,
5656
downsample_tuple = (downsample_conv, downsample_identity))
57-
planes_vec = collect(planes_fn(block_repeats))
5857
# closure over `idxs`
5958
function get_layers(stage_idx::Integer, block_idx::Integer)
6059
# This is needed for block `inplanes` and `planes` calculations

src/layers/Layers.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@ include("attention.jl")
1919
export MHAttention
2020

2121
include("conv.jl")
22-
export conv_norm, basic_conv_bn, dwsep_conv_bn, mbconv, fused_mbconv
22+
export conv_norm, basic_conv_bn, dwsep_conv_bn
2323

2424
include("drop.jl")
2525
export DropBlock, DropPath
2626

2727
include("embeddings.jl")
2828
export PatchEmbedding, ViPosEmbedding, ClassTokens
2929

30+
include("mbconv.jl")
31+
export mbconv, fused_mbconv
32+
3033
include("mlp.jl")
3134
export mlp_block, gated_mlp_block
3235

0 commit comments

Comments
 (0)