Skip to content

Commit 59e1ef4

Browse files
committed
More uniformity + cleanup
1 parent 8ce0dce commit 59e1ef4

30 files changed

+202
-231
lines changed

src/Metalhead.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,12 @@ include("vit-based/vit.jl")
5656
include("pretrain.jl")
5757

5858
export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
59-
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
59+
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
60+
WideResNet, ResNeXt, SEResNet, SEResNeXt,
6061
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
6162
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
6263
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
63-
WideResNet, SEResNet, SEResNeXt,
64-
MLPMixer, ResMLP, gMLP,
65-
ViT,
66-
ConvMixer, ConvNeXt
64+
MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt
6765

6866
# use Flux._big_show to pretty print large models
6967
for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,

src/convnets/alexnet.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""
2-
alexnet(; nclasses::Integer = 1000)
2+
alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
33
44
Create an AlexNet model
55
([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).
66
77
# Arguments
88
9+
- `inchannels`: The number of input channels.
910
- `nclasses`: the number of output classes
1011
"""
1112
function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
@@ -27,19 +28,23 @@ function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
2728
end
2829

2930
"""
30-
AlexNet(; pretrain::Bool = false, nclasses::Integer = 1000)
31+
AlexNet(; pretrain::Bool = false, inchannels::Integer = 3,
32+
nclasses::Integer = 1000)
3133
3234
Create a `AlexNet`.
33-
See also [`alexnet`](#).
34-
35-
!!! warning
36-
37-
`AlexNet` does not currently support pretrained weights.
35+
([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).
3836
3937
# Arguments
4038
4139
- `pretrain`: set to `true` to load pre-trained weights for ImageNet
40+
- `inchannels`: The number of input channels.
4241
- `nclasses`: the number of output classes
42+
43+
!!! warning
44+
45+
`AlexNet` does not currently support pretrained weights.
46+
47+
See also [`alexnet`](#).
4348
"""
4449
struct AlexNet
4550
layers::Any

src/convnets/convmixer.jl

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,28 +26,28 @@ function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
2626
pad = SamePad())), +),
2727
conv_norm((1, 1), planes, planes, activation; preact = true)...)
2828
for _ in 1:depth]
29-
return Chain(Chain(stem..., Chain(blocks)), create_classifier(planes, nclasses))
29+
return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses))
3030
end
3131

32-
const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,
33-
:kernel_size => (9, 9),
34-
:patch_size => (7, 7)),
35-
:small => Dict(:planes => 768, :depth => 32,
36-
:kernel_size => (7, 7),
37-
:patch_size => (7, 7)),
38-
:large => Dict(:planes => 1024, :depth => 20,
39-
:kernel_size => (9, 9),
40-
:patch_size => (7, 7)))
32+
const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20),
33+
(kernel_size = (9, 9),
34+
patch_size = (7, 7))),
35+
:small => ((768, 32),
36+
(kernel_size = (7, 7),
37+
patch_size = (7, 7))),
38+
:large => ((1024, 20),
39+
(kernel_size = (9, 9),
40+
patch_size = (7, 7))))
4141

4242
"""
43-
ConvMixer(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
43+
ConvMixer(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
4444
4545
Creates a ConvMixer model.
4646
([reference](https://arxiv.org/abs/2201.09792))
4747
4848
# Arguments
4949
50-
- `mode`: the mode of the model, either `:base`, `:small` or `:large`
50+
- `config`: the size of the model, either `:base`, `:small` or `:large`
5151
- `inchannels`: The number of channels in the input.
5252
- `nclasses`: number of classes in the output
5353
"""
@@ -56,13 +56,10 @@ struct ConvMixer
5656
end
5757
@functor ConvMixer
5858

59-
function ConvMixer(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
60-
_checkconfig(mode, keys(CONVMIXER_CONFIGS))
61-
planes = CONVMIXER_CONFIGS[mode][:planes]
62-
depth = CONVMIXER_CONFIGS[mode][:depth]
63-
kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size]
64-
patch_size = CONVMIXER_CONFIGS[mode][:patch_size]
65-
layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, nclasses)
59+
function ConvMixer(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
60+
_checkconfig(config, keys(CONVMIXER_CONFIGS))
61+
layers = convmixer(CONVMIXER_CONFIGS[config][1]...; CONVMIXER_CONFIGS[config][2]...,
62+
inchannels, nclasses)
6663
return ConvMixer(layers)
6764
end
6865

src/convnets/convnext.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function convnextblock(planes::Integer, drop_path_rate = 0.0, layerscale_init =
2222
end
2323

2424
"""
25-
convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer};
25+
convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer};
2626
drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3,
2727
nclasses::Integer = 1000)
2828
@@ -31,27 +31,27 @@ Creates the layers for a ConvNeXt model.
3131
3232
# Arguments
3333
34-
- `inchannels`: number of input channels.
3534
- `depths`: list with configuration for depth of each block
3635
- `planes`: list with configuration for number of output channels in each block
3736
- `drop_path_rate`: Stochastic depth rate.
3837
- `layerscale_init`: Initial value for [`LayerScale`](#)
3938
([reference](https://arxiv.org/abs/2103.17239))
39+
- `inchannels`: number of input channels.
4040
- `nclasses`: number of output classes
4141
"""
42-
function convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer};
42+
function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer};
4343
drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3,
4444
nclasses::Integer = 1000)
4545
@assert length(depths) == length(planes)
4646
"`planes` should have exactly one value for each block"
4747
downsample_layers = []
48-
stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
49-
ChannelLayerNorm(planes[1]))
50-
push!(downsample_layers, stem)
48+
push!(downsample_layers,
49+
Chain(conv_norm((4, 4), inchannels => planes[1]; stride = 4,
50+
norm_layer = ChannelLayerNorm)...))
5151
for m in 1:(length(depths) - 1)
52-
downsample_layer = Chain(ChannelLayerNorm(planes[m]),
53-
Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
54-
push!(downsample_layers, downsample_layer)
52+
push!(downsample_layers,
53+
Chain(conv_norm((2, 2), planes[m] => planes[m + 1]; stride = 2,
54+
norm_layer = ChannelLayerNorm, revnorm = true)...))
5555
end
5656
stages = []
5757
dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths))
@@ -64,8 +64,7 @@ function convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer};
6464
end
6565
backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
6666
classifier = Chain(GlobalMeanPool(), MLUtils.flatten,
67-
LayerNorm(planes[end]),
68-
Dense(planes[end], nclasses))
67+
LayerNorm(planes[end]), Dense(planes[end], nclasses))
6968
return Chain(Chain(backbone...), classifier)
7069
end
7170

@@ -77,13 +76,14 @@ const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
7776
:xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))
7877

7978
"""
80-
ConvNeXt(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
79+
ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
8180
8281
Creates a ConvNeXt model.
8382
([reference](https://arxiv.org/abs/2201.03545))
8483
8584
# Arguments
8685
86+
- `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`.
8787
- `inchannels`: The number of channels in the input.
8888
- `nclasses`: number of output classes
8989
@@ -94,9 +94,9 @@ struct ConvNeXt
9494
end
9595
@functor ConvNeXt
9696

97-
function ConvNeXt(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
98-
_checkconfig(mode, keys(CONVNEXT_CONFIGS))
99-
layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, nclasses)
97+
function ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
98+
_checkconfig(config, keys(CONVNEXT_CONFIGS))
99+
layers = convnext(CONVNEXT_CONFIGS[config]...; inchannels, nclasses)
100100
return ConvNeXt(layers)
101101
end
102102

src/convnets/densenet.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ Create a DenseNet model
9999
- `reduction`: the factor by which the number of feature maps is scaled across each transition
100100
- `nclasses`: the number of output classes
101101
"""
102-
function densenet(nblocks::Vector{<:Integer}; growth_rate::Integer = 32, reduction = 0.5,
102+
function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32,
103+
reduction = 0.5,
103104
inchannels::Integer = 3, nclasses::Integer = 1000)
104105
return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks];
105106
reduction, inchannels, nclasses)

src/convnets/efficientnet.jl

Lines changed: 15 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
2222
- `max_width`: maximum number of output channels before the fully connected
2323
classification blocks
2424
"""
25-
function efficientnet(scalings, block_configs; max_width::Integer = 1280,
26-
inchannels::Integer = 3, nclasses::Integer = 1000)
25+
function efficientnet(scalings::NTuple{2, Real},
26+
block_configs::AbstractVector{NTuple{6, Int}};
27+
max_width::Integer = 1280, inchannels::Integer = 3,
28+
nclasses::Integer = 1000)
2729
wscale, dscale = scalings
2830
scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w)
2931
scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d)
@@ -83,61 +85,32 @@ const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)),
8385
:b8 => (672, (2.2, 3.6)))
8486

8587
"""
86-
EfficientNet(scalings, block_configs; max_width::Integer = 1280,
87-
inchannels::Integer = 3, nclasses::Integer = 1000)
88+
EfficientNet(config::Symbol; pretrain::Bool = false)
8889
8990
Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
9091
See also [`efficientnet`](#).
9192
9293
# Arguments
9394
94-
- `scalings`: global width and depth scaling (given as a tuple)
95-
96-
- `block_configs`: configuration for each inverted residual block,
97-
given as a vector of tuples with elements:
98-
99-
+ `n`: number of block repetitions (will be scaled by global depth scaling)
100-
+ `k`: kernel size
101-
+ `s`: kernel stride
102-
+ `e`: expansion ratio
103-
+ `i`: block input channels (will be scaled by global width scaling)
104-
+ `o`: block output channels (will be scaled by global width scaling)
105-
- `inchannels`: number of input channels
106-
- `nclasses`: number of output classes
107-
- `max_width`: maximum number of output channels before the fully connected
108-
classification blocks
95+
- `config`: name of default configuration
96+
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
97+
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
10998
"""
11099
struct EfficientNet
111100
layers::Any
112101
end
113102
@functor EfficientNet
114103

115-
function EfficientNet(scalings, block_configs; max_width::Integer = 1280,
116-
inchannels::Integer = 3, nclasses::Integer = 1000)
117-
layers = efficientnet(scalings, block_configs; inchannels, nclasses, max_width)
118-
return EfficientNet(layers)
104+
function EfficientNet(config::Symbol; pretrain::Bool = false)
105+
_checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
106+
model = efficientnet(EFFICIENTNET_GLOBAL_CONFIGS[config][2], EFFICIENTNET_BLOCK_CONFIGS)
107+
if pretrain
108+
loadpretrain!(model, string("efficientnet-", config))
109+
end
110+
return model
119111
end
120112

121113
(m::EfficientNet)(x) = m.layers(x)
122114

123115
backbone(m::EfficientNet) = m.layers[1]
124116
classifier(m::EfficientNet) = m.layers[2]
125-
126-
"""
127-
EfficientNet(name::Symbol; pretrain::Bool = false)
128-
129-
Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
130-
See also [`efficientnet`](#).
131-
132-
# Arguments
133-
134-
- `name`: name of default configuration
135-
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
136-
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
137-
"""
138-
function EfficientNet(name::Symbol; pretrain::Bool = false)
139-
_checkconfig(name, keys(EFFICIENTNET_GLOBAL_CONFIGS))
140-
model = EfficientNet(EFFICIENTNET_GLOBAL_CONFIGS[name][2], EFFICIENTNET_BLOCK_CONFIGS)
141-
pretrain && loadpretrain!(model, string("efficientnet-", name))
142-
return model
143-
end

src/convnets/inception/googlenet.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ function googlenet(; inchannels::Integer = 3, nclasses::Integer = 1000)
5353
MaxPool((3, 3); stride = 2, pad = 1),
5454
_inceptionblock(832, 256, 160, 320, 32, 128, 128),
5555
_inceptionblock(832, 384, 192, 384, 48, 128, 128))
56-
classifier = create_classifier(1024, nclasses; dropout_rate = 0.4)
57-
return Chain(backbone, classifier)
56+
return Chain(backbone, create_classifier(1024, nclasses; dropout_rate = 0.4))
5857
end
5958

6059
"""

src/convnets/inception/inceptionresnetv2.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ function inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0,
9292
[block8(0.20f0) for _ in 1:9]...,
9393
block8(; activation = relu),
9494
conv_norm((1, 1), 2080, 1536)...)
95-
classifier = create_classifier(1536, nclasses; dropout_rate)
96-
return Chain(backbone, classifier)
95+
return Chain(backbone, create_classifier(1536, nclasses; dropout_rate))
9796
end
9897

9998
"""

src/convnets/inception/inceptionv3.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,7 @@ function inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000)
154154
inceptionv3_d(768),
155155
inceptionv3_e(1280),
156156
inceptionv3_e(2048))
157-
classifier = create_classifier(2048, nclasses; dropout_rate = 0.2)
158-
return Chain(backbone, classifier)
157+
return Chain(backbone, create_classifier(2048, nclasses; dropout_rate = 0.2))
159158
end
160159

161160
"""

src/convnets/inception/inceptionv4.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,7 @@ function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3,
117117
inceptionv4_c(),
118118
inceptionv4_c(),
119119
inceptionv4_c())
120-
classifier = create_classifier(1536, nclasses; dropout_rate)
121-
return Chain(backbone, classifier)
120+
return Chain(backbone, create_classifier(1536, nclasses; dropout_rate))
122121
end
123122

124123
"""

0 commit comments

Comments
 (0)