Skip to content

Commit 251b323

Browse files
committed
More uniform mid level API
1. Expose `pretrain` option for all models 2. Make it easier to initialise models with config options at the mid level by providing an additional dispatch 3. Some cleanup + documentation
1 parent f001221 commit 251b323

File tree

16 files changed

+123
-53
lines changed

16 files changed

+123
-53
lines changed

src/convnets/alexnet.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
"""
2-
alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
2+
alexnet(; dropout_rate = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
33
44
Create an AlexNet model
55
([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).
66
77
# Arguments
88
9+
- `dropout_rate`: dropout rate for the classifier
910
- `inchannels`: The number of input channels.
1011
- `nclasses`: the number of output classes
1112
"""
12-
function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
13+
function alexnet(; dropout_rate = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
1314
backbone = Chain(Conv((11, 11), inchannels => 64, relu; stride = 4, pad = 2),
1415
MaxPool((3, 3); stride = 2),
1516
Conv((5, 5), 64 => 192, relu; pad = 2),
@@ -19,9 +20,9 @@ function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
1920
Conv((3, 3), 256 => 256, relu; pad = 1),
2021
MaxPool((3, 3); stride = 2))
2122
classifier = Chain(AdaptiveMeanPool((6, 6)), MLUtils.flatten,
22-
Dropout(0.5),
23+
Dropout(dropout_rate),
2324
Dense(256 * 6 * 6, 4096, relu),
24-
Dropout(0.5),
25+
Dropout(dropout_rate),
2526
Dense(4096, 4096, relu),
2627
Dense(4096, nclasses))
2728
return Chain(backbone, classifier)

src/convnets/convnext.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:In
6868
return Chain(Chain(backbone...), classifier)
6969
end
7070

71+
function convnext(config::Symbol; drop_path_rate = 0.0, layerscale_init = 1.0f-6,
72+
inchannels::Integer = 3, nclasses::Integer = 1000)
73+
return convnext(CONVNEXT_CONFIGS[config]...; drop_path_rate, layerscale_init,
74+
inchannels, nclasses)
75+
end
76+
7177
# Configurations for ConvNeXt models
7278
const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
7379
:small => ([3, 3, 27, 3], [96, 192, 384, 768]),
@@ -76,27 +82,37 @@ const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
7682
:xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))
7783

7884
"""
79-
ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
85+
ConvNeXt(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
86+
nclasses::Integer = 1000)
8087
8188
Creates a ConvNeXt model.
8289
([reference](https://arxiv.org/abs/2201.03545))
8390
8491
# Arguments
8592
8693
- `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`.
94+
- `pretrain`: set to `true` to load pre-trained weights for ImageNet
8795
- `inchannels`: number of input channels
8896
- `nclasses`: number of output classes
8997
98+
!!! warning
99+
100+
`ConvNeXt` does not currently support pretrained weights.
101+
90102
See also [`Metalhead.convnext`](#).
91103
"""
92104
struct ConvNeXt
93105
layers::Any
94106
end
95107
@functor ConvNeXt
96108

97-
function ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
109+
function ConvNeXt(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
110+
nclasses::Integer = 1000)
98111
_checkconfig(config, keys(CONVNEXT_CONFIGS))
99-
layers = convnext(CONVNEXT_CONFIGS[config]...; inchannels, nclasses)
112+
layers = convnext(config; inchannels, nclasses)
113+
if pretrain
114+
loadpretrain!(layers, "convnext_$config")
115+
end
100116
return ConvNeXt(layers)
101117
end
102118

src/convnets/efficientnets/core.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
function efficientnet(block_configs::AbstractVector{<:Tuple}; inplanes::Integer,
2-
scalings::NTuple{2, Real} = (1, 1),
3-
headplanes::Integer = block_configs[end][3] * 4,
4-
norm_layer = BatchNorm, dropout_rate = nothing,
5-
inchannels::Integer = 3, nclasses::Integer = 1000)
1+
function efficientnet_core(block_configs::AbstractVector{<:Tuple}; inplanes::Integer,
2+
scalings::NTuple{2, Real} = (1, 1),
3+
headplanes::Integer = block_configs[end][3] * 4,
4+
norm_layer = BatchNorm, dropout_rate = nothing,
5+
inchannels::Integer = 3, nclasses::Integer = 1000)
66
layers = []
77
# stem of the model
88
inplanes = _round_channels(inplanes * scalings[1])

src/convnets/efficientnets/efficientnet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Intege
5252
nclasses::Integer = 1000)
5353
_checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
5454
scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
55-
layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32, scalings,
56-
inchannels, nclasses)
55+
layers = efficientnet_core(EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32, scalings,
56+
inchannels, nclasses)
5757
if pretrain
5858
loadpretrain!(layers, string("efficientnet-", config))
5959
end

src/convnets/efficientnets/efficientnetv2.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ function EfficientNetv2(config::Symbol; pretrain::Bool = false,
5959
inchannels::Integer = 3, nclasses::Integer = 1000)
6060
_checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS))))
6161
block_configs = EFFNETV2_CONFIGS[config]
62-
layers = efficientnet(block_configs; inplanes = block_configs[1][3],
63-
headplanes = 1280, inchannels, nclasses)
62+
layers = efficientnet_core(block_configs; inplanes = block_configs[1][3],
63+
headplanes = 1280, inchannels, nclasses)
6464
if pretrain
6565
loadpretrain!(layers, string("efficientnetv2-", config))
6666
end

src/convnets/inceptions/googlenet.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
_inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_3x3, pool_proj)
2+
inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj)
33
44
Create an inception module for use in GoogLeNet
55
([reference](https://arxiv.org/abs/1409.4842v1)).
@@ -14,7 +14,7 @@ Create an inception module for use in GoogLeNet
1414
- `out_5x5`: the number of output feature maps for the 5x5 convolution (branch 3)
1515
- `pool_proj`: the number of output feature maps for the pooling projection (branch 4)
1616
"""
17-
function _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj)
17+
function inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj)
1818
branch1 = Chain(Conv((1, 1), inplanes => out_1x1))
1919
branch2 = Chain(Conv((1, 1), inplanes => red_3x3),
2020
Conv((3, 3), red_3x3 => out_3x3; pad = 1))
@@ -42,17 +42,17 @@ function googlenet(; dropout_rate = 0.4, inchannels::Integer = 3, nclasses::Inte
4242
Conv((1, 1), 64 => 64),
4343
Conv((3, 3), 64 => 192; pad = 1),
4444
MaxPool((3, 3); stride = 2, pad = 1),
45-
_inceptionblock(192, 64, 96, 128, 16, 32, 32),
46-
_inceptionblock(256, 128, 128, 192, 32, 96, 64),
45+
inceptionblock(192, 64, 96, 128, 16, 32, 32),
46+
inceptionblock(256, 128, 128, 192, 32, 96, 64),
4747
MaxPool((3, 3); stride = 2, pad = 1),
48-
_inceptionblock(480, 192, 96, 208, 16, 48, 64),
49-
_inceptionblock(512, 160, 112, 224, 24, 64, 64),
50-
_inceptionblock(512, 128, 128, 256, 24, 64, 64),
51-
_inceptionblock(512, 112, 144, 288, 32, 64, 64),
52-
_inceptionblock(528, 256, 160, 320, 32, 128, 128),
48+
inceptionblock(480, 192, 96, 208, 16, 48, 64),
49+
inceptionblock(512, 160, 112, 224, 24, 64, 64),
50+
inceptionblock(512, 128, 128, 256, 24, 64, 64),
51+
inceptionblock(512, 112, 144, 288, 32, 64, 64),
52+
inceptionblock(528, 256, 160, 320, 32, 128, 128),
5353
MaxPool((3, 3); stride = 2, pad = 1),
54-
_inceptionblock(832, 256, 160, 320, 32, 128, 128),
55-
_inceptionblock(832, 384, 192, 384, 48, 128, 128))
54+
inceptionblock(832, 256, 160, 320, 32, 128, 128),
55+
inceptionblock(832, 384, 192, 384, 48, 128, 128))
5656
return Chain(backbone, create_classifier(1024, nclasses; dropout_rate))
5757
end
5858

src/convnets/mobilenets/mnasnet.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
4141
return Chain(Chain(layers...), create_classifier(max_width, nclasses; dropout_rate))
4242
end
4343

44+
function mnasnet(config::Symbol; width_mult::Real = 1, max_width::Integer = 1280,
45+
dropout_rate = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000)
46+
inplanes, block_configs = MNASNET_CONFIGS[config]
47+
return mnasnet(block_configs; width_mult, max_width, dropout_rate, inplanes,
48+
inchannels, nclasses)
49+
end
50+
4451
# Layer configurations for MNasNet
4552
# f: block function - we use `dwsep_conv_bn` for the first block and `mbconv` for the rest
4653
# k: kernel size
@@ -79,7 +86,8 @@ const MNASNET_CONFIGS = Dict(:B1 => (32,
7986
(mbconv, 5, 32, 6, 2, 4, 4, relu),
8087
(mbconv, 3, 32, 6, 1, 3, 4, relu),
8188
(mbconv, 5, 88, 6, 2, 3, 4, relu),
82-
(mbconv, 3, 144, 6, 1, 1, nothing, relu)]))
89+
(mbconv, 3, 144, 6, 1, 1, nothing, relu),
90+
]))
8391

8492
"""
8593
MNASNet(width_mult = 1; inchannels::Integer = 3, pretrain::Bool = false,
@@ -111,8 +119,7 @@ end
111119
function MNASNet(config::Symbol; width_mult::Real = 1, pretrain::Bool = false,
112120
inchannels::Integer = 3, nclasses::Integer = 1000)
113121
_checkconfig(config, keys(MNASNET_CONFIGS))
114-
inplanes, block_configs = MNASNET_CONFIGS[config]
115-
layers = mnasnet(block_configs; width_mult, inplanes, inchannels, nclasses)
122+
layers = mnasnet(config; width_mult, inchannels, nclasses)
116123
if pretrain
117124
loadpretrain!(layers, "mnasnet$(width_mult)")
118125
end

src/convnets/mobilenets/mobilenetv1.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ function mobilenetv1(config::AbstractVector{<:Tuple}; width_mult::Real = 1,
3737
return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate))
3838
end
3939

40+
function mobilenetv1(width_mult::Real = 1; activation = relu, dropout_rate = nothing,
41+
inchannels::Integer = 3, nclasses::Integer = 1000)
42+
return mobilenetv1(MOBILENETV1_CONFIGS; width_mult, activation,
43+
dropout_rate, inchannels, nclasses)
44+
end
45+
4046
# Layer configurations for MobileNetv1
4147
# f: block function - we use `dwsep_conv_bn` for all blocks
4248
# k: kernel size

src/convnets/mobilenets/mobilenetv2.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real =
4343
return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate))
4444
end
4545

46+
function mobilenetv2(width_mult::Real = 1; max_width::Integer = 1280,
47+
divisor::Integer = 8, inplanes::Integer = 32,
48+
dropout_rate = 0.2, inchannels::Integer = 3,
49+
nclasses::Integer = 1000)
50+
return mobilenetv2(MOBILENETV2_CONFIGS; width_mult, max_width, divisor, inplanes,
51+
dropout_rate, inchannels, nclasses)
52+
end
53+
4654
# Layer configurations for MobileNetv2
4755
# f: block function - we use `mbconv` for all blocks
4856
# k: kernel size

src/convnets/mobilenets/mobilenetv3.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
4646
(hardswish, identity); dropout_rate))
4747
end
4848

49+
function mobilenetv3(config::Symbol; width_mult::Real = 1, dropout_rate = 0.2,
50+
inchannels::Integer = 3, nclasses::Integer = 1000)
51+
max_width = config === :large ? 1280 : 1024
52+
return mobilenetv3(MOBILENETV3_CONFIGS[config]; width_mult, max_width,
53+
dropout_rate, inchannels, nclasses)
54+
end
55+
4956
# Layer configurations for small and large models for MobileNetv3
5057
# f: mbconv block function - we use `mbconv` for all blocks
5158
# k: kernel size
@@ -110,9 +117,7 @@ end
110117
function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false,
111118
inchannels::Integer = 3, nclasses::Integer = 1000)
112119
_checkconfig(config, [:small, :large])
113-
max_width = config == :large ? 1280 : 1024
114-
layers = mobilenetv3(MOBILENETV3_CONFIGS[config]; width_mult, max_width, inchannels,
115-
nclasses)
120+
layers = mobilenetv3(config; width_mult, inchannels, nclasses)
116121
if pretrain
117122
loadpretrain!(layers, string("MobileNetv3", config))
118123
end

0 commit comments

Comments
 (0)