
Commit 5ce0204

The one true function
1. Unify the MobileNet and EfficientNet APIs into a single lower-level builder function.
2. Further unify and consolidate the mid-level API.
3. Some cleanup.
1 parent 251b323 commit 5ce0204
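
The practical upshot of item 1: every inverted-residual family now funnels through the new `irmodelbuilder` added below. A minimal sketch of the (unchanged) public entry points that now share that code path, assuming this revision of Metalhead.jl is installed:

using Metalhead

# Public constructors are untouched; under the hood both now reach
# irmodelbuilder with family-specific block configs and scalings.
effnet = EfficientNet(:b0)         # via the new efficientnet(config) below
effv2  = EfficientNetv2(:small)    # via the new efficientnetv2(config) below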

22 files changed: +243 additions, -415 deletions


src/Metalhead.jl

Lines changed: 4 additions & 3 deletions
@@ -20,26 +20,27 @@ using .Layers
 
 # CNN models
 ## Builders
-include("convnets/builders/core.jl")
+include("convnets/builders/irmodel.jl")
 include("convnets/builders/mbconv.jl")
 include("convnets/builders/resblocks.jl")
+include("convnets/builders/resnet.jl")
+include("convnets/builders/stages.jl")
 ## AlexNet and VGG
 include("convnets/alexnet.jl")
 include("convnets/vgg.jl")
 ## ResNets
 include("convnets/resnets/core.jl")
+include("convnets/resnets/res2net.jl")
 include("convnets/resnets/resnet.jl")
 include("convnets/resnets/resnext.jl")
 include("convnets/resnets/seresnet.jl")
-include("convnets/resnets/res2net.jl")
 ## Inceptions
 include("convnets/inceptions/googlenet.jl")
 include("convnets/inceptions/inceptionv3.jl")
 include("convnets/inceptions/inceptionv4.jl")
 include("convnets/inceptions/inceptionresnetv2.jl")
 include("convnets/inceptions/xception.jl")
 ## EfficientNets
-include("convnets/efficientnets/core.jl")
 include("convnets/efficientnets/efficientnet.jl")
 include("convnets/efficientnets/efficientnetv2.jl")
 ## MobileNets

src/convnets/builders/irmodel.jl

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+function irmodelbuilder(scalings::NTuple{2, Real}, block_configs::AbstractVector{<:Tuple};
+                        inplanes::Integer = 32, connection = +, activation = relu,
+                        norm_layer = BatchNorm, divisor::Integer = 8,
+                        tail_conv::Bool = true, expanded_classifier::Bool = false,
+                        headplanes::Integer, dropout_rate = nothing,
+                        inchannels::Integer = 3, nclasses::Integer = 1000, kwargs...)
+    width_mult, _ = scalings
+    # building first layer
+    inplanes = _round_channels(inplanes * width_mult, divisor)
+    layers = []
+    append!(layers,
+            conv_norm((3, 3), inchannels, inplanes, activation; stride = 2, pad = 1,
+                      norm_layer))
+    # building inverted residual blocks
+    get_layers, block_repeats = mbconv_stage_builder(block_configs, inplanes, scalings;
+                                                     norm_layer, divisor, kwargs...)
+    append!(layers, cnn_stages(get_layers, block_repeats, connection))
+    # building last layers
+    outplanes = _round_channels(block_configs[end][3] * width_mult, divisor)
+    if tail_conv
+        # special case, supported fully only for MobileNetv3
+        if expanded_classifier
+            midplanes = _round_channels(outplanes * block_configs[end][4], divisor)
+            append!(layers,
+                    conv_norm((1, 1), outplanes, midplanes, activation; norm_layer))
+            classifier = create_classifier(midplanes, headplanes, nclasses,
+                                           (hardswish, identity); dropout_rate)
+        else
+            append!(layers,
+                    conv_norm((1, 1), outplanes, headplanes, activation; norm_layer))
+            classifier = create_classifier(headplanes, nclasses; dropout_rate)
+        end
+    else
+        classifier = create_classifier(outplanes, nclasses; dropout_rate)
+    end
+    return Chain(Chain(layers...), classifier)
+end
+
+function irmodelbuilder(width_mult::Real, block_configs::AbstractVector{<:Tuple}; kwargs...)
+    return irmodelbuilder((width_mult, 1), block_configs; kwargs...)
+end
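
To make the new builder's contract concrete, here is a hypothetical direct call. `irmodelbuilder` is internal and unexported, the module paths are approximate, and the two-stage config is invented purely for illustration; it follows the (block_fn, kernel, outplanes, expansion, stride, nrepeats, reduction, activation) tuple layout destructured by the mbconv builders:

using Flux        # for relu
using Metalhead

# Invented two-stage config; not a real model. mbconv may live under
# Metalhead.Layers depending on how the module re-exports it.
block_configs = [
    (Metalhead.mbconv, 3, 16, 1, 1, 1, 4, relu),   # stage 1: one block
    (Metalhead.mbconv, 3, 24, 6, 2, 2, 4, relu),   # stage 2: two blocks, stride 2
]
# scalings = (width_mult, depth_mult); headplanes has no default and must be passed.
toy = Metalhead.irmodelbuilder((1.0, 1.0), block_configs;
                               inplanes = 16, headplanes = 128, nclasses = 10)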

src/convnets/builders/mbconv.jl

Lines changed: 28 additions & 60 deletions
@@ -1,9 +1,14 @@
-function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
-                           width_mult::Real; norm_layer = BatchNorm, kwargs...)
+# TODO - potentially make these builders more flexible to specify stuff like
+# activation functions and reductions that don't change over the stages
+
+function dwsepconv_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
+                           stage_idx::Integer, scalings::NTuple{2, Real};
+                           norm_layer = BatchNorm, divisor::Integer = 8, kwargs...)
+    width_mult, depth_mult = scalings
     block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx]
-    outplanes = _round_channels(outplanes * width_mult)
+    outplanes = _round_channels(outplanes * width_mult, divisor)
     if stage_idx != 1
-        inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult)
+        inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor)
     end
     function get_layers(block_idx::Integer)
         inplanes = block_idx == 1 ? inplanes : outplanes
@@ -12,13 +17,14 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
                               stride, pad = SamePad(), norm_layer, kwargs...)...)
         return (block,)
     end
-    return get_layers, nrepeats
+    return get_layers, ceil(Int, nrepeats * depth_mult)
 end
+_get_builder(::typeof(dwsep_conv_norm)) = dwsepconv_builder
 
-function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
-                        scalings::NTuple{2, Real}; norm_layer = BatchNorm,
-                        divisor::Integer = 8, se_from_explanes::Bool = false,
-                        kwargs...)
+function mbconv_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
+                        stage_idx::Integer, scalings::NTuple{2, Real};
+                        norm_layer = BatchNorm, divisor::Integer = 8,
+                        se_from_explanes::Bool = false, kwargs...)
     width_mult, depth_mult = scalings
     block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx]
     # calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes
@@ -39,69 +45,31 @@ function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
     end
     return get_layers, ceil(Int, nrepeats * depth_mult)
 end
+_get_builder(::typeof(mbconv)) = mbconv_builder
 
-function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
-                        width_mult::Real; norm_layer = BatchNorm, kwargs...)
-    return mbconv_builder(block_configs, inplanes, stage_idx, (width_mult, 1);
-                          norm_layer, kwargs...)
-end
-
-function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer;
-                              norm_layer = BatchNorm, kwargs...)
+function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
+                              stage_idx::Integer, scalings::NTuple{2, Real};
+                              norm_layer = BatchNorm, divisor::Integer = 8, kwargs...)
+    width_mult, depth_mult = scalings
     block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx]
     inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3]
+    outplanes = _round_channels(outplanes * width_mult, divisor)
     function get_layers(block_idx::Integer)
         inplanes = block_idx == 1 ? inplanes : outplanes
-        explanes = _round_channels(inplanes * expansion, 8)
+        explanes = _round_channels(inplanes * expansion, divisor)
         stride = block_idx == 1 ? stride : 1
         block = block_fn((k, k), inplanes, explanes, outplanes, activation;
                          norm_layer, stride, kwargs...)
         return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
     end
-    return get_layers, nrepeats
-end
-
-# TODO - these builders need to be more flexible to potentially specify stuff like
-# activation functions and reductions that don't change
-function _get_builder(::typeof(dwsep_conv_bn), block_configs::AbstractVector{<:Tuple},
-                      inplanes::Integer, stage_idx::Integer;
-                      scalings::Union{Nothing, NTuple{2, Real}} = nothing,
-                      width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
-    @assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument"
-    return dwsepconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
-                             kwargs...)
-end
-
-function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
-                      inplanes::Integer, stage_idx::Integer;
-                      scalings::Union{Nothing, NTuple{2, Real}} = nothing,
-                      width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
-    if isnothing(scalings)
-        return mbconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
-                              kwargs...)
-    elseif isnothing(width_mult)
-        return mbconv_builder(block_configs, inplanes, stage_idx, scalings; norm_layer,
-                              kwargs...)
-    else
-        throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified"))
-    end
-end
-
-function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple},
-                      inplanes::Integer, stage_idx::Integer;
-                      scalings::Union{Nothing, NTuple{2, Real}} = nothing,
-                      width_mult::Union{Nothing, Number} = nothing, norm_layer)
-    @assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument."
-    @assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument"
-    return fused_mbconv_builder(block_configs, inplanes, stage_idx; norm_layer)
+    return get_layers, ceil(Int, nrepeats * depth_mult)
 end
+_get_builder(::typeof(fused_mbconv)) = fused_mbconv_builder
 
-function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer;
-                              scalings::Union{Nothing, NTuple{2, Real}} = nothing,
-                              width_mult::Union{Nothing, Number} = nothing,
-                              norm_layer = BatchNorm, kwargs...)
-    bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx; scalings,
-                        width_mult, norm_layer, kwargs...)
+function mbconv_stage_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
+                              scalings::NTuple{2, Real}; kwargs...)
+    builders = _get_builder.(first.(block_configs))
+    bxs = [builders[idx](block_configs, inplanes, idx, scalings; kwargs...)
           for idx in eachindex(block_configs)]
     return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs)
 end
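
The old `_get_builder` methods validated `scalings` against `width_mult` with keyword plumbing; the replacements are one-line dispatch entries keyed on the block function's type, and `mbconv_stage_builder` simply broadcasts the lookup. A sketch of the resulting contract, using internal names that may shift (values illustrative):

using Flux, Metalhead

# One illustrative mbconv stage starting from 8 input channels.
configs = [(Metalhead.mbconv, 3, 16, 4, 2, 2, 4, relu)]
get_layers, block_repeats = Metalhead.mbconv_stage_builder(configs, 8, (1.0, 1.0))
# block_repeats == [2] (nrepeats scaled by depth_mult = 1.0);
# get_layers(stage_idx, block_idx) returns the layer tuple for one block,
# either (block,) or (identity, block) when a residual connection applies.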

src/convnets/builders/resnet.jl

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+function resnetbuilder(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
+                       connection, classifier_fn)
+    # Build stages of the ResNet
+    stage_blocks = cnn_stages(get_layers, block_repeats, connection)
+    backbone = Chain(stem, stage_blocks...)
+    # Add classifier to the backbone
+    nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
+    return Chain(backbone, classifier_fn(nfeaturemaps))
+end
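
The classifier width here comes from shape inference rather than manual bookkeeping: `Flux.outputsize` propagates a dummy input shape through the backbone without allocating real activations, and index 3 of the WHCN result is the channel count. A self-contained sketch of that mechanism (plain Flux, not Metalhead code):

using Flux

backbone = Chain(Conv((7, 7), 3 => 64; stride = 2, pad = 3),
                 MaxPool((3, 3); stride = 2, pad = 1))
# padbatch = true appends a batch dimension to (224, 224, 3).
nfeaturemaps = Flux.outputsize(backbone, (224, 224, 3); padbatch = true)[3]   # 64
classifier = Dense(nfeaturemaps => 1000)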
File renamed without changes.

src/convnets/convnext.jl

Lines changed: 3 additions & 3 deletions
@@ -46,11 +46,11 @@ function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:In
             "`planes` should have exactly one value for each block"
     downsample_layers = []
     push!(downsample_layers,
-          Chain(conv_norm((4, 4), inchannels => planes[1]; stride = 4,
+          Chain(conv_norm((4, 4), inchannels, planes[1]; stride = 4,
                           norm_layer = ChannelLayerNorm)...))
     for m in 1:(length(depths) - 1)
         push!(downsample_layers,
-              Chain(conv_norm((2, 2), planes[m] => planes[m + 1]; stride = 2,
+              Chain(conv_norm((2, 2), planes[m], planes[m + 1]; stride = 2,
                               norm_layer = ChannelLayerNorm, revnorm = true)...))
     end
     stages = []
@@ -111,7 +111,7 @@ function ConvNeXt(config::Symbol; pretrain::Bool = true, inchannels::Integer = 3
     _checkconfig(config, keys(CONVNEXT_CONFIGS))
     layers = convnext(config; inchannels, nclasses)
     if pretrain
-        layers = load_pretrained(layers, "convnext_$config")
+        layers = loadpretrain!(layers, "convnext_$config")
     end
     return ConvNeXt(layers)
 end
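
The `conv_norm` edits above are part of the mid-level consolidation: channel counts move from a Pair (`inchannels => planes[1]`) to plain positional arguments, matching the new builder files. A sketch of the call shape; treating `Metalhead.Layers.conv_norm` as the path is an assumption, and `conv_norm` returns a vector of layers, hence the splat:

using Flux, Metalhead

# before: conv_norm((4, 4), 3 => 96; stride = 4, norm_layer = ChannelLayerNorm)
# after:  positional in/out channels
stem = Chain(Metalhead.Layers.conv_norm((4, 4), 3, 96; stride = 4)...)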

src/convnets/efficientnets/core.jl

Lines changed: 0 additions & 21 deletions
This file was deleted.

src/convnets/efficientnets/efficientnet.jl

Lines changed: 12 additions & 4 deletions
@@ -31,6 +31,17 @@ const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)),
                                          :b7 => (600, (2.0, 3.1)),
                                          :b8 => (672, (2.2, 3.6)))
 
+function efficientnet(config::Symbol; norm_layer = BatchNorm,
+                      dropout_rate = nothing, inchannels::Integer = 3,
+                      nclasses::Integer = 1000)
+    _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
+    scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
+    return irmodelbuilder(scalings, EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32,
+                          norm_layer, activation = swish,
+                          headplanes = EFFICIENTNET_BLOCK_CONFIGS[end][3] * 4,
+                          dropout_rate, inchannels, nclasses)
+end
+
 """
     EfficientNet(config::Symbol; pretrain::Bool = false)
@@ -50,10 +61,7 @@ end
 
 function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
                       nclasses::Integer = 1000)
-    _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
-    scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
-    layers = efficientnet_core(EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32, scalings,
-                               inchannels, nclasses)
+    layers = efficientnet(config; inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, string("efficientnet-", config))
     end
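
With the logic hoisted into the lowercase `efficientnet`, the exported constructor is a thin wrapper plus optional weight loading. Usage is unchanged; per EFFICIENTNET_GLOBAL_CONFIGS, :b0 expects 224×224 input:

using Metalhead

model = EfficientNet(:b0)                    # pretrain defaults to false
y = model(rand(Float32, 224, 224, 3, 1))     # logits of size (1000, 1)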

src/convnets/efficientnets/efficientnetv2.jl

Lines changed: 10 additions & 4 deletions
@@ -35,6 +35,15 @@ const EFFNETV2_CONFIGS = Dict(:small => [(fused_mbconv, 3, 24, 1, 1, 2, swish),
                               (mbconv, 3, 512, 6, 2, 32, 4, swish),
                               (mbconv, 3, 768, 6, 1, 8, 4, swish)])
 
+function efficientnetv2(config::Symbol; norm_layer = BatchNorm, dropout_rate = nothing,
+                        inchannels::Integer = 3, nclasses::Integer = 1000)
+    _checkconfig(config, keys(EFFNETV2_CONFIGS))
+    block_configs = EFFNETV2_CONFIGS[config]
+    return irmodelbuilder((1, 1), block_configs; activation = swish, norm_layer,
+                          inplanes = block_configs[1][3], headplanes = 1280,
+                          dropout_rate, inchannels, nclasses)
+end
+
 """
     EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1,
                    inchannels::Integer = 3, nclasses::Integer = 1000)
@@ -57,10 +66,7 @@ end
 
 function EfficientNetv2(config::Symbol; pretrain::Bool = false,
                         inchannels::Integer = 3, nclasses::Integer = 1000)
-    _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS))))
-    block_configs = EFFNETV2_CONFIGS[config]
-    layers = efficientnet_core(block_configs; inplanes = block_configs[1][3],
-                               headplanes = 1280, inchannels, nclasses)
+    layers = efficientnetv2(config; inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, string("efficientnetv2-", config))
     end
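
The same delegation pattern for v2, with a small simplification: `_checkconfig` now takes `keys(EFFNETV2_CONFIGS)` directly instead of a sorted, collected copy. Keyword arguments still flow through the wrapper:

using Metalhead

model = EfficientNetv2(:small; inchannels = 1, nclasses = 10)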

src/convnets/inceptions/xception.jl

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int
         end
         push!(layers, relu)
         append!(layers,
-                dwsep_conv_bn((3, 3), inc, outc; pad = 1, use_norm = (false, false)))
+                dwsep_conv_norm((3, 3), inc, outc; pad = 1, use_norm = (false, false)))
         push!(layers, BatchNorm(outc))
     end
     layers = start_with_relu ? layers : layers[2:end]
@@ -62,8 +62,8 @@ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integ
                      xception_block(256, 728, 2; stride = 2),
                      [xception_block(728, 728, 3) for _ in 1:8]...,
                      xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
-                     dwsep_conv_bn((3, 3), 1024, 1536; pad = 1)...,
-                     dwsep_conv_bn((3, 3), 1536, 2048; pad = 1)...)
+                     dwsep_conv_norm((3, 3), 1024, 1536; pad = 1)...,
+                     dwsep_conv_norm((3, 3), 1536, 2048; pad = 1)...)
     return Chain(backbone, create_classifier(2048, nclasses; dropout_rate))
 end
 