Skip to content

Commit 1e4c669

Browse files
authored
Merge pull request #197 from theabhirath/weights
Add pretrained weights on ImageNet for some models
2 parents 7449985 + d07bd6e commit 1e4c669

File tree

13 files changed

+135
-99
lines changed

13 files changed

+135
-99
lines changed

Artifacts.toml

Lines changed: 82 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,75 @@
1+
[resnet101]
2+
git-tree-sha1 = "68d563526ab34d3e5aa66b7d96278d2acde212f9"
3+
lazy = true
4+
5+
[[resnet101.download]]
6+
sha256 = "0725f05db5772cfab1024b8d0d6c85ac1fc5a83eb6f0fe02b67b1e689d5a28db"
7+
url = "https://huggingface.co/FluxML/resnet101/resolve/980158099e6917b74ade2b0a9599359f06057d21/resnet101.tar.gz"
8+
9+
[resnet152]
10+
git-tree-sha1 = "85a97464b6cef66e1217ae829d3620651cffab47"
11+
lazy = true
12+
13+
[[resnet152.download]]
14+
sha256 = "a8d30a735ef5649ec40a74a0515ee3d6774499267be06f5f2b372259c5ced8d6"
15+
url = "https://huggingface.co/FluxML/resnet152/resolve/a66a3e1f5056179d167cb2165401950e3890b34d/resnet152.tar.gz"
16+
17+
[resnet18]
18+
git-tree-sha1 = "4ced5a0338c0f0293940f1deb63e1c463125a6ff"
19+
lazy = true
20+
21+
[[resnet18.download]]
22+
sha256 = "9444ef2285f507bd890d2ca852d663749f079110ed19b544e8d91f67f3cc6b83"
23+
url = "https://huggingface.co/FluxML/resnet18/resolve/9b1c6c4f7c5dbe734d80d7d4b5f132ef58bf2467/resnet18.tar.gz"
24+
25+
[resnet34]
26+
git-tree-sha1 = "485519977f375ca1770b3ff3971f61e438823f5a"
27+
lazy = true
28+
29+
[[resnet34.download]]
30+
sha256 = "71ed75be6db0160af7f30be33e2f4a44836310949d1374267510e5803b1fb313"
31+
url = "https://huggingface.co/FluxML/resnet34/resolve/0988ae2d4a86da06eefa6b61edf3e728861e286c/resnet34.tar.gz"
32+
33+
[resnet50]
34+
git-tree-sha1 = "2973be0da60544080105756ecb3951cca2e007da"
35+
lazy = true
36+
37+
[[resnet50.download]]
38+
sha256 = "60ad32eaf160444f3bfdb6f6d81ec1e5c36a3769be7df22aaa75127b16bb1501"
39+
url = "https://huggingface.co/FluxML/resnet50/resolve/1529d6ddca42e3e705cb708c9de6f79188ce8ad5/resnet50.tar.gz"
40+
41+
[resnext101_32x8d]
42+
git-tree-sha1 = "d13a85131b2c0c62ef2af79a09137d4e0760a685"
43+
lazy = true
44+
45+
[[resnext101_32x8d.download]]
46+
sha256 = "aeb48f86f50ee8b0ca7dc01ca0ff5a2d2b2163e43c524203c4a8bd589db9bcc6"
47+
url = "https://huggingface.co/FluxML/resnext101_32x8d/resolve/e060f030c445f644112efa2a00e3c544944046e1/resnext101_32x8d.tar.gz"
48+
49+
[resnext101_64x4d]
50+
git-tree-sha1 = "db50f48614e673a40f98fb80a17688b34f42067a"
51+
lazy = true
52+
53+
[[resnext101_64x4d.download]]
54+
sha256 = "89764dd7dc3b3432f0424cb592cec5d9db2fb802ab1646f0e3c2cca2b2e5386b"
55+
url = "https://huggingface.co/FluxML/resnext101_64x4d/resolve/0d0485da04efe5a53289a560d105c42d3ca5435c/resnext101_64x4d.tar.gz"
56+
57+
[resnext50_32x4d]
58+
git-tree-sha1 = "1e7a08a4acae690b635e8d1caa06e75eeb2dd2fe"
59+
lazy = true
60+
61+
[[resnext50_32x4d.download]]
62+
sha256 = "084ccbc40fde07496c401ee2bc389b9cd1d60b1ac3b7ccbfde05479ea91ca707"
63+
url = "https://huggingface.co/FluxML/resnext50_32x4d/resolve/150c52c9646fe697030d38ab2be767564fb4f28c/resnext50_32x4d.tar.gz"
64+
65+
[squeezenet]
66+
git-tree-sha1 = "e2eeee109fda46470d657b13669cca09d5ef2f8c"
67+
lazy = true
68+
69+
[[squeezenet.download]]
70+
sha256 = "aebfa06f44767e5ff728b7b67b2d01352b4618bd5d305c26e603aabcd5ba593d"
71+
url = "https://huggingface.co/FluxML/squeezenet/resolve/01ef4221df5260bd992c669ab587eed74df0c39f/squeezenet.tar.gz"
72+
173
[vgg11]
274
git-tree-sha1 = "78ffe7d74c475cc28175f9e23a545ce2f17b1520"
375
lazy = true
@@ -30,42 +102,18 @@ lazy = true
30102
sha256 = "5fe26391572b9f6ac84eaa0541d27e959f673f82e6515026cdcd3262cbd93ceb"
31103
url = "https://huggingface.co/FluxML/vgg19/resolve/88e9056f60b054eccdc190a2eeb23731d5c693b6/vgg19.tar.gz"
32104

33-
[resnet18]
34-
git-tree-sha1 = "7b555ed2708e551bfdbcb7e71b25001f4b3731c6"
35-
lazy = true
36-
37-
[[resnet18.download]]
38-
sha256 = "d5782fd873a3072df251c7a4b3cf16efca8ee1da1180ff815bc107833f84bb26"
39-
url = "https://huggingface.co/FluxML/resnet18/resolve/ef9c74047fda4a4a503b1f72553ec05acc90929f/resnet18.tar.gz"
40-
41-
[resnet34]
42-
git-tree-sha1 = "e6e79666cd0fc81cd828508314e6c7f66df8d43d"
43-
lazy = true
44-
45-
[[resnet34.download]]
46-
sha256 = "a8dec13609a86f7a2adac6a44b3af912a863bc2d7319120066c5fdaa04c3f395"
47-
url = "https://huggingface.co/FluxML/resnet34/resolve/42061ddb463902885eea4fcc85275462a5445987/resnet34.tar.gz"
48-
49-
[resnet50]
50-
git-tree-sha1 = "5c442ffd6c51a70c3bc36d849fca86beced446d4"
51-
lazy = true
52-
53-
[[resnet50.download]]
54-
sha256 = "5325920ec91c2a4499ad7e659961f9eaac2b1a3a2905ca6410eaa593ecd35503"
55-
url = "https://huggingface.co/FluxML/resnet50/resolve/10e601719e1cd5b0cab87ce7fd1e8f69a07ce042/resnet50.tar.gz"
56-
57-
[resnet101]
58-
git-tree-sha1 = "694a8563ec20fb826334dd663d532b10bb2b3c97"
105+
[wideresnet101]
106+
git-tree-sha1 = "b881a9469fb230faff414ce9f983bc113061ab1c"
59107
lazy = true
60108

61-
[[resnet101.download]]
62-
sha256 = "f4d737ce640957c30f76bfa642fc9da23e6852d81474d58a2338c1148e55bff0"
63-
url = "https://huggingface.co/FluxML/resnet101/resolve/ea37819163cc3f4a41989a6239ce505e483b112d/resnet101.tar.gz"
109+
[[wideresnet101.download]]
110+
sha256 = "defa61fd80a988bb07bb9db00c692d8d0a30d95e6276add1413fb7f1f3aa2607"
111+
url = "https://huggingface.co/FluxML/wideresnet101/resolve/ad4df1016bb5eba4c10e2d37b049ca5d2a455670/wideresnet101.tar.gz"
64112

65-
[resnet152]
66-
git-tree-sha1 = "55eb883248a276d710d75ecaecfbd2427e50cc0a"
113+
[wideresnet50]
114+
git-tree-sha1 = "bbc6bc632e743c992784b5121dcb0f6082c66b1f"
67115
lazy = true
68116

69-
[[resnet152.download]]
70-
sha256 = "57be335e6828d1965c9d11f933d2d41f51e5e534f9bfdbde01c6144fa8862a4d"
71-
url = "https://huggingface.co/FluxML/resnet152/resolve/ba28814d5746643387b5c0e1d2269104e5e9bc8d/resnet152.tar.gz"
117+
[[wideresnet50.download]]
118+
sha256 = "7596a67b7aba762c2bfce8367055da79a8a3c117bd79ce11124c9f1f5a96c4e3"
119+
url = "https://huggingface.co/FluxML/wideresnet50/resolve/5eca9979f5d9438a684b1e6a5b227f9d5611965a/wideresnet50.tar.gz"

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
|:-------------------------------------------------|:------------------------------------------------------------------------------------------|:------------:|
1919
| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.VGG.html) | Y (w/o BN) |
2020
| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNet.html) | Y |
21+
| [WideResNet](https://arxiv.org/abs/1605.07146) | [`WideResNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.WideResNet.html) | Y |
2122
| [GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.GoogLeNet.html) | N |
2223
| [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.Inceptionv3.html) | N |
2324
| [Inception-v4](https://arxiv.org/abs/1602.07261) | [`Inceptionv4`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.Inceptionv4.html) | N |
2425
| [InceptionResNet-v2](https://arxiv.org/abs/1602.07261) | [`InceptionResNetv2`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.InceptionResNetv2.html) | N |
25-
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.SqueezeNet.html) | N |
26+
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | [`SqueezeNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.SqueezeNet.html) | Y |
2627
| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.DenseNet.html) | N |
27-
| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNeXt.html) | N |
28+
| [ResNeXt](https://arxiv.org/abs/1611.05431) | [`ResNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNeXt.html) | Y |
2829
| [MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv1.html) | N |
2930
| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv2.html) | N |
3031
| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.MobileNetv3.html) | N |

src/convnets/densenet.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ Create a DenseNet model
100100
- `nclasses`: the number of output classes
101101
"""
102102
function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32,
103-
reduction = 0.5,
104-
inchannels::Integer = 3, nclasses::Integer = 1000)
103+
reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
105104
return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks];
106105
reduction, inchannels, nclasses)
107106
end
@@ -133,11 +132,11 @@ end
133132
function DenseNet(config::Integer; pretrain::Bool = false, growth_rate::Integer = 32,
134133
reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
135134
_checkconfig(config, keys(DENSENET_CONFIGS))
136-
model = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses)
135+
layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses)
137136
if pretrain
138-
loadpretrain!(model, string("DenseNet", config))
137+
loadpretrain!(layers, string("densenet", config))
139138
end
140-
return model
139+
return DenseNet(layers)
141140
end
142141

143142
(m::DenseNet)(x) = m.layers(x)

src/convnets/inception/inceptionv3.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
## Inceptionv3
2-
31
"""
42
inceptionv3_a(inplanes, pool_proj)
53
@@ -176,6 +174,7 @@ See also [`inceptionv3`](#).
176174
struct Inceptionv3
177175
layers::Any
178176
end
177+
@functor Inceptionv3
179178

180179
function Inceptionv3(; pretrain::Bool = false, inchannels::Integer = 3,
181180
nclasses::Integer = 1000)
@@ -186,8 +185,6 @@ function Inceptionv3(; pretrain::Bool = false, inchannels::Integer = 3,
186185
return Inceptionv3(layers)
187186
end
188187

189-
@functor Inceptionv3
190-
191188
(m::Inceptionv3)(x) = m.layers(x)
192189

193190
backbone(m::Inceptionv3) = m.layers[1]

src/convnets/resnets/core.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,15 +280,13 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con
280280
end
281281

282282
function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
283-
connection,
284-
classifier_fn)
283+
connection, classifier_fn)
285284
# Build stages of the ResNet
286285
stage_blocks = resnet_stages(get_layers, block_repeats, connection)
287286
backbone = Chain(stem, stage_blocks)
288-
# Build the classifier head
287+
# Add classifier to the backbone
289288
nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
290-
classifier = classifier_fn(nfeaturemaps)
291-
return Chain(backbone, classifier)
289+
return Chain(backbone, classifier_fn(nfeaturemaps))
292290
end
293291

294292
function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};

src/convnets/resnets/resnet.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@ Creates a ResNet model with the specified depth.
1111
- `inchannels`: The number of input channels.
1212
- `nclasses`: the number of output classes
1313
14-
!!! warning
15-
16-
`ResNet` does not currently support pretrained weights.
17-
1814
Advanced users who want more configuration options will be better served by using [`resnet`](#).
1915
"""
2016
struct ResNet
@@ -27,7 +23,7 @@ function ResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3,
2723
_checkconfig(depth, keys(RESNET_CONFIGS))
2824
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses)
2925
if pretrain
30-
loadpretrain!(layers, string("ResNet", depth))
26+
loadpretrain!(layers, string("resnet", depth))
3127
end
3228
return ResNet(layers)
3329
end
@@ -52,10 +48,6 @@ The number of channels in outer 1x1 convolutions is the same.
5248
- `inchannels`: The number of input channels.
5349
- `nclasses`: the number of output classes
5450
55-
!!! warning
56-
57-
`WideResNet` does not currently support pretrained weights.
58-
5951
Advanced users who want more configuration options will be better served by using [`resnet`](#).
6052
"""
6153
struct WideResNet

src/convnets/resnets/resnext.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
"""
2-
ResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32,
3-
base_width = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
2+
ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
3+
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
44
55
Creates a ResNeXt model with the specified depth, cardinality, and base width.
66
([reference](https://arxiv.org/abs/1611.05431))
77
88
# Arguments
99
1010
- `depth`: one of `[50, 101, 152]`. The depth of the ResNeXt model.
11-
- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
11+
- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet.
12+
Supported configurations are:
13+
- depth 50, cardinality of 32 and base width of 4.
14+
- depth 101, cardinality of 32 and base width of 8.
15+
- depth 101, cardinality of 64 and base width of 4.
1216
- `cardinality`: the number of groups to be used in the 3x3 convolution in each block.
1317
- `base_width`: the number of feature maps in each group.
1418
- `inchannels`: the number of input channels.
1519
- `nclasses`: the number of output classes
1620
17-
!!! warning
18-
19-
`ResNeXt` does not currently support pretrained weights.
20-
2121
Advanced users who want more configuration options will be better served by using [`resnet`](#).
2222
"""
2323
struct ResNeXt
@@ -27,12 +27,12 @@ end
2727

2828
(m::ResNeXt)(x) = m.layers(x)
2929

30-
function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32,
31-
base_width = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
30+
function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
31+
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
3232
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
3333
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
3434
if pretrain
35-
loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width))
35+
loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width, "d"))
3636
end
3737
return ResNeXt(layers)
3838
end

src/convnets/resnets/seresnet.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer =
3030
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses,
3131
attn_fn = squeeze_excite)
3232
if pretrain
33-
loadpretrain!(layers, string("SEResNet", depth))
33+
loadpretrain!(layers, string("seresnet", depth))
3434
end
3535
return SEResNet(layers)
3636
end
@@ -39,8 +39,8 @@ backbone(m::SEResNet) = m.layers[1]
3939
classifier(m::SEResNet) = m.layers[2]
4040

4141
"""
42-
SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32, base_width = 4,
43-
inchannels::Integer = 3, nclasses::Integer = 1000)
42+
SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
43+
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
4444
4545
Creates a SEResNeXt model with the specified depth, cardinality, and base width.
4646
([reference](https://arxiv.org/pdf/1709.01507.pdf))
@@ -67,13 +67,13 @@ end
6767

6868
(m::SEResNeXt)(x) = m.layers(x)
6969

70-
function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32, base_width = 4,
71-
inchannels::Integer = 3, nclasses::Integer = 1000)
70+
function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
71+
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
7272
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
7373
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
7474
attn_fn = squeeze_excite)
7575
if pretrain
76-
loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width))
76+
loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width))
7777
end
7878
return SEResNeXt(layers)
7979
end

src/convnets/squeezenet.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,6 @@ Create a SqueezeNet
6262
- `inchannels`: number of input channels.
6363
- `nclasses`: the number of output classes.
6464
65-
!!! warning
66-
67-
`SqueezeNet` does not currently support pretrained weights.
68-
6965
See also [`squeezenet`](#).
7066
"""
7167
struct SqueezeNet
@@ -77,7 +73,7 @@ function SqueezeNet(; pretrain::Bool = false, inchannels::Integer = 3,
7773
nclasses::Integer = 1000)
7874
layers = squeezenet(; inchannels, nclasses)
7975
if pretrain
80-
loadpretrain!(layers, "SqueezeNet")
76+
loadpretrain!(layers, "squeezenet")
8177
end
8278
return SqueezeNet(layers)
8379
end

src/layers/mlp.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,19 @@ Creates a classifier head to be used for models.
6666
function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity;
6767
use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)),
6868
dropout_rate = nothing)
69-
# Pooling
69+
# Decide whether to flatten the input or not
7070
flatten_in_pool = !use_conv && pool_layer !== identity
7171
if use_conv
7272
@assert pool_layer === identity
7373
"`pool_layer` must be identity if `use_conv` is true"
7474
end
75-
global_pool = flatten_in_pool ? [pool_layer, MLUtils.flatten] : [pool_layer]
75+
classifier = []
76+
flatten_in_pool ? push!(classifier, pool_layer, MLUtils.flatten) :
77+
push!(classifier, pool_layer)
78+
# Dropout is applied after the pooling layer
79+
isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate))
7680
# Fully-connected layer
77-
fc = use_conv ? Conv((1, 1), inplanes => nclasses, activation) :
78-
Dense(inplanes => nclasses, activation)
79-
drop = isnothing(dropout_rate) ? [] : [Dropout(dropout_rate)]
80-
return Chain(global_pool..., drop..., fc)
81+
use_conv ? push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) :
82+
push!(classifier, Dense(inplanes => nclasses, activation))
83+
return Chain(classifier...)
8184
end

0 commit comments

Comments
 (0)