
Commit ee344a9
Merge pull request #168 from theabhirath/style-3
2 parents f97a61d + d4f1d07

29 files changed: +1495 −1382 lines

.JuliaFormatter.toml

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+style = "sciml"
+whitespace_in_kwargs = true
+format_docstrings = true
+always_for_in = true
+join_lines_based_on_source = true
+separate_kwargs_with_semicolon = true
+always_use_return = true
+margin = 92
+indent = 4
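With this config at the repository root, JuliaFormatter applies the SciML style options above across the package. A minimal sketch of running it locally (assuming a recent JuliaFormatter release that discovers .JuliaFormatter.toml automatically; the paths are illustrative):

using JuliaFormatter

# format(path) walks the directory tree and reads the nearest
# .JuliaFormatter.toml, so the options above are applied as-is.
format("src")
format("test")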

.git-blame-ignore-revs

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# .git-blame-ignore-revs
+# Switched to SciML style for code
+d5d28f0ef6e1e253ecf3fdbbec2f511836c8767b
+70d639de532b046980cbea8d17fb1829e04cccfe
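These are the hashes of the bulk reformatting commits. With git 2.23 or newer, reviewers can hide them from blame output by pointing git at this file, e.g. `git config blame.ignoreRevsFile .git-blame-ignore-revs`.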

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 name = "Metalhead"
 uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
-version = "0.7.1"
+version = "0.7.2-DEV"
 
 [deps]
 Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -16,7 +16,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 BSON = "0.3.2"
 Flux = "0.13"
 Functors = "0.2"
-MLUtils = "0.2"
+MLUtils = "0.2.6"
 NNlib = "0.7.34, 0.8"
 julia = "1.6"

src/Metalhead.jl

Lines changed: 13 additions & 12 deletions
@@ -37,22 +37,23 @@ include("vit-based/vit.jl")
 
 include("pretrain.jl")
 
-export AlexNet,
-    VGG, VGG11, VGG13, VGG16, VGG19,
-    ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
-    GoogLeNet, Inception3, SqueezeNet,
-    DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
-    ResNeXt,
-    MobileNetv1, MobileNetv2, MobileNetv3,
-    MLPMixer, ResMLP, gMLP,
-    ViT,
-    ConvNeXt, ConvMixer
+export AlexNet,
+       VGG, VGG11, VGG13, VGG16, VGG19,
+       ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
+       GoogLeNet, Inception3, SqueezeNet,
+       DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
+       ResNeXt,
+       MobileNetv1, MobileNetv2, MobileNetv3,
+       MLPMixer, ResMLP, gMLP,
+       ViT,
+       ConvNeXt, ConvMixer
 
 # use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
+for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet,
+          :ResNeXt,
           :MobileNetv1, :MobileNetv2, :MobileNetv3,
          :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvNeXt, :ConvMixer)
-  @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
+    @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
 
 end # module
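For reference, each pass through the `@eval` loop above defines one `show` method; the first iteration expands to the equivalent of:

# what the @eval loop generates for T = :AlexNet
Base.show(io::IO, ::MIME"text/plain", model::AlexNet) = _maybe_big_show(io, model)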

src/convnets/alexnet.jl

Lines changed: 27 additions & 25 deletions
@@ -5,26 +5,27 @@ Create an AlexNet model
 ([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).
 
 # Arguments
-- `nclasses`: the number of output classes
+
+  - `nclasses`: the number of output classes
 """
 function alexnet(; nclasses = 1000)
-  layers = Chain(Chain(Conv((11, 11), 3 => 64, stride = (4, 4), relu, pad = (2, 2)),
-                       MaxPool((3, 3), stride = (2, 2)),
-                       Conv((5, 5), 64 => 192, relu, pad = (2, 2)),
-                       MaxPool((3, 3), stride = (2, 2)),
-                       Conv((3, 3), 192 => 384, relu, pad = (1, 1)),
-                       Conv((3, 3), 384 => 256, relu, pad = (1, 1)),
-                       Conv((3, 3), 256 => 256, relu, pad = (1, 1)),
-                       MaxPool((3, 3), stride = (2, 2)),
-                       AdaptiveMeanPool((6,6))),
-                 Chain(MLUtils.flatten,
-                       Dropout(0.5),
-                       Dense(256 * 6 * 6, 4096, relu),
-                       Dropout(0.5),
-                       Dense(4096, 4096, relu),
-                       Dense(4096, nclasses)))
-
-  return layers
+    layers = Chain(Chain(Conv((11, 11), 3 => 64, relu; stride = (4, 4), pad = (2, 2)),
+                         MaxPool((3, 3); stride = (2, 2)),
+                         Conv((5, 5), 64 => 192, relu; pad = (2, 2)),
+                         MaxPool((3, 3); stride = (2, 2)),
+                         Conv((3, 3), 192 => 384, relu; pad = (1, 1)),
+                         Conv((3, 3), 384 => 256, relu; pad = (1, 1)),
+                         Conv((3, 3), 256 => 256, relu; pad = (1, 1)),
+                         MaxPool((3, 3); stride = (2, 2)),
+                         AdaptiveMeanPool((6, 6))),
+                   Chain(MLUtils.flatten,
+                         Dropout(0.5),
+                         Dense(256 * 6 * 6, 4096, relu),
+                         Dropout(0.5),
+                         Dense(4096, 4096, relu),
+                         Dense(4096, nclasses)))
+
+    return layers
 end
 
 """
@@ -34,21 +35,22 @@ Create a `AlexNet`.
 See also [`alexnet`](#).
 
 !!! warning
+
     `AlexNet` does not currently support pretrained weights.
 
 # Arguments
-- `pretrain`: set to `true` to load pre-trained weights for ImageNet
-- `nclasses`: the number of output classes
+
+  - `pretrain`: set to `true` to load pre-trained weights for ImageNet
+  - `nclasses`: the number of output classes
 """
 struct AlexNet
-  layers
+    layers::Any
 end
 
 function AlexNet(; pretrain = false, nclasses = 1000)
-  layers = alexnet(nclasses = nclasses)
-  pretrain && loadpretrain!(layers, "AlexNet")
-
-  AlexNet(layers)
+    layers = alexnet(; nclasses = nclasses)
+    pretrain && loadpretrain!(layers, "AlexNet")
+    return AlexNet(layers)
 end
 
 @functor AlexNet
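Aside from the semicolon now separating keyword arguments, the constructor behaves as before. A quick usage sketch (hypothetical class count; assumes Metalhead is installed):

using Metalhead

model = AlexNet(; nclasses = 10)   # randomly initialised; pretrain defaults to false
x = rand(Float32, 224, 224, 3, 1)  # one 224×224 RGB image in WHCN layout
y = model.layers(x)                # 10×1 matrix of class scores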

src/convnets/convmixer.jl

Lines changed: 35 additions & 28 deletions
@@ -6,30 +6,35 @@ Creates a ConvMixer model.
 ([reference](https://arxiv.org/abs/2201.09792))
 
 # Arguments
-- `planes`: number of planes in the output of each block
-- `depth`: number of layers
-- `inchannels`: number of channels in the input
-- `kernel_size`: kernel size of the convolutional layers
-- `patch_size`: size of the patches
-- `activation`: activation function used after the convolutional layers
-- `nclasses`: number of classes in the output
+
+  - `planes`: number of planes in the output of each block
+  - `depth`: number of layers
+  - `inchannels`: number of channels in the input
+  - `kernel_size`: kernel size of the convolutional layers
+  - `patch_size`: size of the patches
+  - `activation`: activation function used after the convolutional layers
+  - `nclasses`: number of classes in the output
 """
 function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
                    patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
-  stem = conv_bn(patch_size, inchannels, planes, activation; preact = true, stride = patch_size[1])
-  blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
-                                               preact = true, groups = planes, pad = SamePad())), +),
-                  conv_bn((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth]
-  head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
-  return Chain(Chain(stem..., Chain(blocks)), head)
+    stem = conv_bn(patch_size, inchannels, planes, activation; preact = true,
+                   stride = patch_size[1])
+    blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
+                                                 preact = true, groups = planes,
+                                                 pad = SamePad())), +),
+                    conv_bn((1, 1), planes, planes, activation; preact = true)...)
+              for _ in 1:depth]
+    head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
+    return Chain(Chain(stem..., Chain(blocks)), head)
 end
 
 convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9),
-                                        :patch_size => (7, 7)),
+                                      :patch_size => (7, 7)),
                         :small => Dict(:planes => 768, :depth => 32, :kernel_size => (7, 7),
-                                         :patch_size => (7, 7)),
-                        :large => Dict(:planes => 1024, :depth => 20, :kernel_size => (9, 9),
-                                         :patch_size => (7, 7)))
+                                       :patch_size => (7, 7)),
+                        :large => Dict(:planes => 1024, :depth => 20,
+                                       :kernel_size => (9, 9),
+                                       :patch_size => (7, 7)))
 
 """
     ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
@@ -38,22 +43,24 @@ Creates a ConvMixer model.
 ([reference](https://arxiv.org/abs/2201.09792))
 
 # Arguments
-- `mode`: the mode of the model, either `:base`, `:small` or `:large`
-- `inchannels`: number of channels in the input
-- `activation`: activation function used after the convolutional layers
-- `nclasses`: number of classes in the output
+
+  - `mode`: the mode of the model, either `:base`, `:small` or `:large`
+  - `inchannels`: number of channels in the input
+  - `activation`: activation function used after the convolutional layers
+  - `nclasses`: number of classes in the output
 """
 struct ConvMixer
-  layers
+    layers::Any
 end
 
 function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
-  planes = convmixer_config[mode][:planes]
-  depth = convmixer_config[mode][:depth]
-  kernel_size = convmixer_config[mode][:kernel_size]
-  patch_size = convmixer_config[mode][:patch_size]
-  layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation, nclasses)
-  return ConvMixer(layers)
+    planes = convmixer_config[mode][:planes]
+    depth = convmixer_config[mode][:depth]
+    kernel_size = convmixer_config[mode][:kernel_size]
+    patch_size = convmixer_config[mode][:patch_size]
+    layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
+                       nclasses)
+    return ConvMixer(layers)
 end
 
 @functor ConvMixer
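The `mode` symbol just indexes into `convmixer_config` above. A short usage sketch (hypothetical class count):

using Metalhead

# :small selects planes = 768, depth = 32, kernel_size = (7, 7), patch_size = (7, 7)
model = ConvMixer(:small; nclasses = 10)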

src/convnets/convnext.jl

Lines changed: 71 additions & 62 deletions
@@ -5,19 +5,20 @@ Creates a single block of ConvNeXt.
 ([reference](https://arxiv.org/abs/2201.03545))
 
 # Arguments:
-- `planes`: number of input channels.
-- `drop_path_rate`: Stochastic depth rate.
-- `λ`: Init value for LayerScale
+
+  - `planes`: number of input channels.
+  - `drop_path_rate`: Stochastic depth rate.
+  - `λ`: Init value for LayerScale
 """
-function convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
-  layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
-                                swapdims((3, 1, 2, 4)),
-                                LayerNorm(planes; ϵ = 1f-6),
-                                mlp_block(planes, 4 * planes),
-                                LayerScale(planes, λ),
-                                swapdims((2, 3, 1, 4)),
-                                DropPath(drop_path_rate)), +)
-  return layers
+function convnextblock(planes, drop_path_rate = 0.0, λ = 1.0f-6)
+    layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
+                                  swapdims((3, 1, 2, 4)),
+                                  LayerNorm(planes; ϵ = 1.0f-6),
+                                  mlp_block(planes, 4 * planes),
+                                  LayerScale(planes, λ),
+                                  swapdims((2, 3, 1, 4)),
+                                  DropPath(drop_path_rate)), +)
+    return layers
 end
 
 """
@@ -27,52 +28,59 @@ Creates the layers for a ConvNeXt model.
 ([reference](https://arxiv.org/abs/2201.03545))
 
 # Arguments:
-- `inchannels`: number of input channels.
-- `depths`: list with configuration for depth of each block
-- `planes`: list with configuration for number of output channels in each block
-- `drop_path_rate`: Stochastic depth rate.
-- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
-- `nclasses`: number of output classes
+
+  - `inchannels`: number of input channels.
+  - `depths`: list with configuration for depth of each block
+  - `planes`: list with configuration for number of output channels in each block
+  - `drop_path_rate`: Stochastic depth rate.
+  - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
+  - `nclasses`: number of output classes
 """
-function convnext(depths, planes; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)
-  @assert length(depths) == length(planes) "`planes` should have exactly one value for each block"
-
-  downsample_layers = []
-  stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
-               ChannelLayerNorm(planes[1]; ϵ = 1f-6))
-  push!(downsample_layers, stem)
-  for m in 1:length(depths) - 1
-    downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1f-6),
-                             Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
-    push!(downsample_layers, downsample_layer)
-  end
-
-  stages = []
-  dp_rates = LinRange{Float32}(0., drop_path_rate, sum(depths))
-  cur = 0
-  for i in 1:length(depths)
-    push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
-    cur += depths[i]
-  end
-
-  backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
-  head = Chain(GlobalMeanPool(),
-               MLUtils.flatten,
-               LayerNorm(planes[end]),
-               Dense(planes[end], nclasses))
-
-  return Chain(Chain(backbone), head)
+function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
+                  nclasses = 1000)
+    @assert length(depths)==length(planes) "`planes` should have exactly one value for each block"
+
+    downsample_layers = []
+    stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
+                 ChannelLayerNorm(planes[1]; ϵ = 1.0f-6))
+    push!(downsample_layers, stem)
+    for m in 1:(length(depths) - 1)
+        downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1.0f-6),
+                                 Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
+        push!(downsample_layers, downsample_layer)
+    end
+
+    stages = []
+    dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths))
+    cur = 0
+    for i in 1:length(depths)
+        push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
+        cur += depths[i]
+    end
+
+    backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
+    head = Chain(GlobalMeanPool(),
+                 MLUtils.flatten,
+                 LayerNorm(planes[end]),
+                 Dense(planes[end], nclasses))
+
+    return Chain(Chain(backbone), head)
 end
 
 # Configurations for ConvNeXt models
-convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3], :planes => [96, 192, 384, 768]),
-                        :small => Dict(:depths => [3, 3, 27, 3], :planes => [96, 192, 384, 768]),
-                        :base => Dict(:depths => [3, 3, 27, 3], :planes => [128, 256, 512, 1024]),
-                        :large => Dict(:depths => [3, 3, 27, 3], :planes => [192, 384, 768, 1536]),
-                        :xlarge => Dict(:depths => [3, 3, 27, 3], :planes => [256, 512, 1024, 2048]))
+convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3],
+                                      :planes => [96, 192, 384, 768]),
+                        :small => Dict(:depths => [3, 3, 27, 3],
+                                       :planes => [96, 192, 384, 768]),
+                        :base => Dict(:depths => [3, 3, 27, 3],
+                                      :planes => [128, 256, 512, 1024]),
+                        :large => Dict(:depths => [3, 3, 27, 3],
+                                       :planes => [192, 384, 768, 1536]),
+                        :xlarge => Dict(:depths => [3, 3, 27, 3],
+                                        :planes => [256, 512, 1024, 2048]))
 
 struct ConvNeXt
-  layers
+    layers::Any
 end
 
 """
@@ -82,20 +90,21 @@ Creates a ConvNeXt model.
 ([reference](https://arxiv.org/abs/2201.03545))
 
 # Arguments:
-- `inchannels`: number of input channels.
-- `drop_path_rate`: Stochastic depth rate.
-- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
-- `nclasses`: number of output classes
+
+  - `inchannels`: number of input channels.
+  - `drop_path_rate`: Stochastic depth rate.
+  - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
+  - `nclasses`: number of output classes
 
 See also [`Metalhead.convnext`](#).
 """
-function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6,
+function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
                   nclasses = 1000)
-  @assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))"
-  depths = convnext_configs[mode][:depths]
-  planes = convnext_configs[mode][:planes]
-  layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
-  return ConvNeXt(layers)
+    @assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))"
+    depths = convnext_configs[mode][:depths]
+    planes = convnext_configs[mode][:planes]
+    layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
+    return ConvNeXt(layers)
 end
 
 (m::ConvNeXt)(x) = m.layers(x)
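As with ConvMixer, the `mode` symbol indexes `convnext_configs`, and `dp_rates` spreads the stochastic-depth rate linearly from 0 to `drop_path_rate` across the blocks. A usage sketch (hypothetical settings):

using Metalhead

# :tiny selects depths = [3, 3, 9, 3] and planes = [96, 192, 384, 768];
# drop_path_rate = 0.1 yields per-block DropPath rates from 0.0 up to 0.1
model = ConvNeXt(:tiny; drop_path_rate = 0.1)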
