
Commit e9306c3

Expose inchannels and nclasses for every model
Also: (a) add more type annotations; (b) expose only configurations vital to the model API (with respect to pretraining) at the highest level.
1 parent cd486df commit e9306c3

34 files changed: +363 -347 lines
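
In practice, the unified constructor surface looks like the sketch below (keyword values are illustrative, not taken from the diff; only the signatures are):

using Metalhead

AlexNet(; inchannels = 1, nclasses = 10)
ConvMixer(:base; inchannels = 3, nclasses = 100)
ConvNeXt(:tiny; inchannels = 3, nclasses = 100)
DenseNet([6, 12, 24, 16]; inchannels = 3, nclasses = 10)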

.github/workflows/CI.yml

Lines changed: 1 addition & 2 deletions
@@ -34,8 +34,7 @@ jobs:
           - '"Inception"'
           - '"DenseNet"'
           - '["ConvNeXt", "ConvMixer"]'
-          - 'r"ViTs"'
-          - 'r"Mixers"'
+          - '[r"ViTs", r"Mixers"]'
     steps:
       - uses: actions/checkout@v2
       - uses: julia-actions/setup-julia@v1

src/convnets/alexnet.jl

Lines changed: 6 additions & 6 deletions
@@ -1,5 +1,5 @@
 """
-    alexnet(; nclasses = 1000)
+    alexnet(; nclasses::Integer = 1000)
 
 Create an AlexNet model
 ([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).
@@ -8,8 +8,8 @@ Create an AlexNet model
 
 - `nclasses`: the number of output classes
 """
-function alexnet(; nclasses = 1000)
-    layers = Chain(Chain(Conv((11, 11), 3 => 64, relu; stride = (4, 4), pad = (2, 2)),
+function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
+    layers = Chain(Chain(Conv((11, 11), inchannels => 64, relu; stride = (4, 4), pad = (2, 2)),
                          MaxPool((3, 3); stride = (2, 2)),
                          Conv((5, 5), 64 => 192, relu; pad = (2, 2)),
                          MaxPool((3, 3); stride = (2, 2)),
@@ -28,7 +28,7 @@ function alexnet(; nclasses = 1000)
 end
 
 """
-    AlexNet(; pretrain = false, nclasses = 1000)
+    AlexNet(; pretrain::Bool = false, nclasses::Integer = 1000)
 
 Create a `AlexNet`.
 See also [`alexnet`](#).
@@ -47,8 +47,8 @@ struct AlexNet
 end
 @functor AlexNet
 
-function AlexNet(; pretrain = false, nclasses = 1000)
-    layers = alexnet(; nclasses = nclasses)
+function AlexNet(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
+    layers = alexnet(; inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, "AlexNet")
     end
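
With `inchannels` now threaded through to the stem convolution, AlexNet can be built for non-RGB input. A minimal sketch of the new keyword API (values illustrative):

using Metalhead

model = AlexNet(; inchannels = 1, nclasses = 10)  # single-channel input, 10-class head
x = rand(Float32, 224, 224, 1, 4)                 # WHCN batch of four 224x224 images
size(model(x))                                    # (10, 4)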

src/convnets/convmixer.jl

Lines changed: 10 additions & 10 deletions
@@ -1,6 +1,7 @@
 """
-    convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9), patch_size::Dims{2} = 7,
-              activation = gelu, nclasses = 1000)
+    convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
+              patch_size::Dims{2} = (7, 7), activation = gelu,
+              inchannels::Integer = 3, nclasses::Integer = 1000)
 
 Creates a ConvMixer model.
 ([reference](https://arxiv.org/abs/2201.09792))
@@ -9,14 +10,15 @@ Creates a ConvMixer model.
 
 - `planes`: number of planes in the output of each block
 - `depth`: number of layers
-- `inchannels`: The number of channels in the input.
 - `kernel_size`: kernel size of the convolutional layers
 - `patch_size`: size of the patches
 - `activation`: activation function used after the convolutional layers
+- `inchannels`: The number of channels in the input.
 - `nclasses`: number of classes in the output
 """
-function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
-                   patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
+function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
+                   patch_size::Dims{2} = (7, 7), activation = gelu,
+                   inchannels::Integer = 3, nclasses::Integer = 1000)
     stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,
                      stride = patch_size[1])
     blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
@@ -39,7 +41,7 @@ const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,
                                              :patch_size => (7, 7)))
 
 """
-    ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
+    ConvMixer(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
 
 Creates a ConvMixer model.
 ([reference](https://arxiv.org/abs/2201.09792))
@@ -48,22 +50,20 @@ Creates a ConvMixer model.
 
 - `mode`: the mode of the model, either `:base`, `:small` or `:large`
 - `inchannels`: The number of channels in the input.
-- `activation`: activation function used after the convolutional layers
 - `nclasses`: number of classes in the output
 """
 struct ConvMixer
     layers::Any
 end
 @functor ConvMixer
 
-function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
+function ConvMixer(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
     _checkconfig(mode, keys(CONVMIXER_CONFIGS))
     planes = CONVMIXER_CONFIGS[mode][:planes]
     depth = CONVMIXER_CONFIGS[mode][:depth]
     kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size]
     patch_size = CONVMIXER_CONFIGS[mode][:patch_size]
-    layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
-                       nclasses)
+    layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, nclasses)
     return ConvMixer(layers)
 end
 
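
Note that `mode` no longer has a default and `activation` is no longer exposed on the high-level constructor (it is fixed by `convmixer`'s `gelu` default). A sketch of the trimmed API (values illustrative):

using Metalhead

model = ConvMixer(:small; inchannels = 3, nclasses = 100)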

src/convnets/convnext.jl

Lines changed: 22 additions & 20 deletions
@@ -1,5 +1,5 @@
 """
-    convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
+    convnextblock(planes::Integer, drop_path_rate = 0.0, layerscale_init = 1.0f-6)
 
 Creates a single block of ConvNeXt.
 ([reference](https://arxiv.org/abs/2201.03545))
@@ -8,21 +8,23 @@ Creates a single block of ConvNeXt.
 
 - `planes`: number of input channels.
 - `drop_path_rate`: Stochastic depth rate.
-- `λ`: Initial value for [`LayerScale`](#)
+- `layerscale_init`: Initial value for [`LayerScale`](#)
 """
-function convnextblock(planes, drop_path_rate = 0.0, λ = 1.0f-6)
+function convnextblock(planes::Integer, drop_path_rate = 0.0, layerscale_init = 1.0f-6)
     layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
                                   swapdims((3, 1, 2, 4)),
                                   LayerNorm(planes; ϵ = 1.0f-6),
                                   mlp_block(planes, 4 * planes),
-                                  LayerScale(planes, λ),
+                                  LayerScale(planes, layerscale_init),
                                   swapdims((2, 3, 1, 4)),
                                   DropPath(drop_path_rate)), +)
     return layers
 end
 
 """
-    convnext(depths, planes; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)
+    convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer};
+             drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3,
+             nclasses::Integer = 1000)
 
 Creates the layers for a ConvNeXt model.
 ([reference](https://arxiv.org/abs/2201.03545))
@@ -33,12 +35,13 @@ Creates the layers for a ConvNeXt model.
 - `depths`: list with configuration for depth of each block
 - `planes`: list with configuration for number of output channels in each block
 - `drop_path_rate`: Stochastic depth rate.
-- `λ`: Initial value for [`LayerScale`](#)
+- `layerscale_init`: Initial value for [`LayerScale`](#)
   ([reference](https://arxiv.org/abs/2103.17239))
 - `nclasses`: number of output classes
 """
-function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
-                  nclasses = 1000)
+function convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer};
+                  drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3,
+                  nclasses::Integer = 1000)
     @assert length(depths) == length(planes)
     "`planes` should have exactly one value for each block"
     downsample_layers = []
@@ -54,7 +57,9 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0
     dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths))
     cur = 0
     for i in eachindex(depths)
-        push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
+        push!(stages,
+              [convnextblock(planes[i], dp_rates[cur + j], layerscale_init)
+               for j in 1:depths[i]])
         cur += depths[i]
     end
     backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
@@ -72,30 +77,27 @@ const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
                               :large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
                               :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))
 
-struct ConvNeXt
-    layers::Any
-end
-@functor ConvNeXt
-
 """
-    ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)
+    ConvNeXt(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
 
 Creates a ConvNeXt model.
 ([reference](https://arxiv.org/abs/2201.03545))
 
 # Arguments
 
 - `inchannels`: The number of channels in the input.
-- `drop_path_rate`: Stochastic depth rate.
-- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
 - `nclasses`: number of output classes
 
 See also [`Metalhead.convnext`](#).
 """
-function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
-                  nclasses = 1000)
+struct ConvNeXt
+    layers::Any
+end
+@functor ConvNeXt
+
+function ConvNeXt(mode::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
     _checkconfig(mode, keys(CONVNEXT_CONFIGS))
-    layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses)
+    layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, nclasses)
     return ConvNeXt(layers)
 end
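
`drop_path_rate` and the LayerScale init are likewise dropped from the high-level constructor; anyone who needs them can still call the lower-level `Metalhead.convnext` directly. A sketch (the `:tiny` depths/planes come from `CONVNEXT_CONFIGS` above; the `drop_path_rate` value is illustrative):

using Metalhead

# High-level: only the pretraining-relevant keywords remain
model = ConvNeXt(:tiny; inchannels = 3, nclasses = 1000)

# Low-level: full control, e.g. over stochastic depth
layers = Metalhead.convnext([3, 3, 9, 3], [96, 192, 384, 768];
                            drop_path_rate = 0.1, layerscale_init = 1.0f-6,
                            inchannels = 3, nclasses = 1000)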

src/convnets/densenet.jl

Lines changed: 24 additions & 25 deletions
@@ -10,7 +10,7 @@ Create a Densenet bottleneck layer
 - `outplanes`: number of output feature maps on bottleneck branch
   (and scaling factor for inner feature maps; see ref)
 """
-function dense_bottleneck(inplanes, outplanes)
+function dense_bottleneck(inplanes::Integer, outplanes::Integer)
     inner_channels = 4 * outplanes
     return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false,
                                           revnorm = true)...,
@@ -30,7 +30,7 @@ Create a DenseNet transition sequence
 - `inplanes`: number of input feature maps
 - `outplanes`: number of output feature maps
 """
-function transition(inplanes, outplanes)
+function transition(inplanes::Integer, outplanes::Integer)
     return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)...,
                  MeanPool((2, 2)))
 end
@@ -48,14 +48,14 @@ the number of output feature maps by `growth_rates` with each block
 - `growth_rates`: the growth (additive) rates of output feature maps
   after each block (a vector of `k`s from the ref)
 """
-function dense_block(inplanes, growth_rates)
+function dense_block(inplanes::Integer, growth_rates)
     return [dense_bottleneck(i, o)
             for (i, o) in zip(inplanes .+ cumsum([0, growth_rates[1:(end - 1)]...]),
                               growth_rates)]
 end
 
 """
-    densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
+    densenet(inplanes, growth_rates; reduction = 0.5, nclasses::Integer = 1000)
 
 Create a DenseNet model
 ([reference](https://arxiv.org/abs/1608.06993)).
@@ -68,9 +68,11 @@ Create a DenseNet model
 - `reduction`: the factor by which the number of feature maps is scaled across each transition
 - `nclasses`: the number of output classes
 """
-function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
+function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::Integer = 3,
+                  nclasses::Integer = 1000)
     layers = []
-    append!(layers, conv_norm((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false))
+    append!(layers,
+            conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3), bias = false))
     push!(layers, MaxPool((3, 3); stride = 2, pad = (1, 1)))
     outplanes = 0
     for (i, rates) in enumerate(growth_rates)
@@ -88,7 +90,7 @@ function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
 end
 
 """
-    densenet(nblocks; growth_rate = 32, reduction = 0.5, nclasses = 1000)
+    densenet(nblocks; growth_rate = 32, reduction = 0.5, nclasses::Integer = 1000)
 
 Create a DenseNet model
 ([reference](https://arxiv.org/abs/1608.06993)).
@@ -100,15 +102,15 @@ Create a DenseNet model
 - `reduction`: the factor by which the number of feature maps is scaled across each transition
 - `nclasses`: the number of output classes
 """
-function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5,
-                  nclasses = 1000) where {N}
+function densenet(nblocks::Vector{<:Integer}; growth_rate::Integer = 32, reduction = 0.5,
+                  inchannels::Integer = 3, nclasses::Integer = 1000)
     return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks];
-                    reduction = reduction, nclasses = nclasses)
+                    reduction, inchannels, nclasses)
 end
 
 """
-    DenseNet(nblocks::NTuple{N, <:Integer};
-             growth_rate = 32, reduction = 0.5, nclasses = 1000)
+    DenseNet(nblocks::Vector{<:Integer}; growth_rate::Integer = 32, reduction = 0.5,
+             inchannels = 3, nclasses::Integer = 1000)
 
 Create a DenseNet model
 ([reference](https://arxiv.org/abs/1608.06993)).
@@ -124,29 +126,26 @@ See also [`densenet`](#).
 struct DenseNet
     layers::Any
 end
+@functor DenseNet
 
-function DenseNet(nblocks::NTuple{N, <:Integer};
-                  growth_rate = 32, reduction = 0.5, nclasses = 1000) where {N}
-    layers = densenet(nblocks; growth_rate = growth_rate,
-                      reduction = reduction,
-                      nclasses = nclasses)
+function DenseNet(nblocks::Vector{<:Integer}; growth_rate::Integer = 32, reduction = 0.5,
+                  inchannels = 3, nclasses::Integer = 1000)
+    layers = densenet(nblocks; growth_rate, reduction, inchannels, nclasses)
     return DenseNet(layers)
 end
 
-@functor DenseNet
-
 (m::DenseNet)(x) = m.layers(x)
 
 backbone(m::DenseNet) = m.layers[1]
 classifier(m::DenseNet) = m.layers[2]
 
-const DENSENET_CONFIGS = Dict(121 => (6, 12, 24, 16),
-                              161 => (6, 12, 36, 24),
-                              169 => (6, 12, 32, 32),
-                              201 => (6, 12, 48, 32))
+const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16],
+                              161 => [6, 12, 36, 24],
+                              169 => [6, 12, 32, 32],
+                              201 => [6, 12, 48, 32])
 
 """
-    DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
+    DenseNet(config::Integer = 121; pretrain::Bool = false, nclasses::Integer = 1000)
     DenseNet(transition_configs::NTuple{N,Integer})
 
 Create a DenseNet model with specified configuration. Currently supported values are (121, 161, 169, 201)
@@ -159,7 +158,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet.
 
 See also [`Metalhead.densenet`](#).
 """
-function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
+function DenseNet(config::Integer = 121; pretrain::Bool = false, nclasses::Integer = 1000)
     _checkconfig(config, keys(DENSENET_CONFIGS))
     model = DenseNet(DENSENET_CONFIGS[config]; nclasses = nclasses)
     if pretrain
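
Since the block configuration type changed from `NTuple` to `Vector`, custom configurations must now be passed as vectors. A sketch (the custom block layout and keyword values are illustrative):

using Metalhead

model = DenseNet(121; nclasses = 10)                  # standard config selected by depth
custom = DenseNet([4, 8, 16, 12]; growth_rate = 24,   # custom block layout, now a Vector
                  inchannels = 3, nclasses = 10)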
