Skip to content

Commit 134d9cb

Browse files
committed
Add back ResNet family and SqueezeNet
Cleanup, but also DenseNets don't work just yet
1 parent 058a37f commit 134d9cb

File tree

7 files changed

+95
-37
lines changed

7 files changed

+95
-37
lines changed

Artifacts.toml

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,66 @@
1+
[resnet101]
2+
git-tree-sha1 = "68d563526ab34d3e5aa66b7d96278d2acde212f9"
3+
4+
[[resnet101.download]]
5+
sha256 = "0725f05db5772cfab1024b8d0d6c85ac1fc5a83eb6f0fe02b67b1e689d5a28db"
6+
url = "https://huggingface.co/FluxML/resnet101/resolve/980158099e6917b74ade2b0a9599359f06057d21/resnet101.tar.gz"
7+
8+
[resnet152]
9+
git-tree-sha1 = "85a97464b6cef66e1217ae829d3620651cffab47"
10+
11+
[[resnet152.download]]
12+
sha256 = "a8d30a735ef5649ec40a74a0515ee3d6774499267be06f5f2b372259c5ced8d6"
13+
url = "https://huggingface.co/FluxML/resnet152/resolve/a66a3e1f5056179d167cb2165401950e3890b34d/resnet152.tar.gz"
14+
15+
[resnet18]
16+
git-tree-sha1 = "4ced5a0338c0f0293940f1deb63e1c463125a6ff"
17+
18+
[[resnet18.download]]
19+
sha256 = "9444ef2285f507bd890d2ca852d663749f079110ed19b544e8d91f67f3cc6b83"
20+
url = "https://huggingface.co/FluxML/resnet18/resolve/9b1c6c4f7c5dbe734d80d7d4b5f132ef58bf2467/resnet18.tar.gz"
21+
22+
[resnet34]
23+
git-tree-sha1 = "485519977f375ca1770b3ff3971f61e438823f5a"
24+
25+
[[resnet34.download]]
26+
sha256 = "71ed75be6db0160af7f30be33e2f4a44836310949d1374267510e5803b1fb313"
27+
url = "https://huggingface.co/FluxML/resnet34/resolve/0988ae2d4a86da06eefa6b61edf3e728861e286c/resnet34.tar.gz"
28+
29+
[resnet50]
30+
git-tree-sha1 = "2973be0da60544080105756ecb3951cca2e007da"
31+
32+
[[resnet50.download]]
33+
sha256 = "60ad32eaf160444f3bfdb6f6d81ec1e5c36a3769be7df22aaa75127b16bb1501"
34+
url = "https://huggingface.co/FluxML/resnet50/resolve/1529d6ddca42e3e705cb708c9de6f79188ce8ad5/resnet50.tar.gz"
35+
36+
[resnext101_32x8d]
37+
git-tree-sha1 = "d13a85131b2c0c62ef2af79a09137d4e0760a685"
38+
39+
[[resnext101_32x8d.download]]
40+
sha256 = "aeb48f86f50ee8b0ca7dc01ca0ff5a2d2b2163e43c524203c4a8bd589db9bcc6"
41+
url = "https://huggingface.co/FluxML/resnext101_32x8d/resolve/e060f030c445f644112efa2a00e3c544944046e1/resnext101_32x8d.tar.gz"
42+
43+
[resnext101_64x4d]
44+
git-tree-sha1 = "db50f48614e673a40f98fb80a17688b34f42067a"
45+
46+
[[resnext101_64x4d.download]]
47+
sha256 = "89764dd7dc3b3432f0424cb592cec5d9db2fb802ab1646f0e3c2cca2b2e5386b"
48+
url = "https://huggingface.co/FluxML/resnext101_64x4d/resolve/0d0485da04efe5a53289a560d105c42d3ca5435c/resnext101_64x4d.tar.gz"
49+
50+
[resnext50_32x4d]
51+
git-tree-sha1 = "1e7a08a4acae690b635e8d1caa06e75eeb2dd2fe"
52+
53+
[[resnext50_32x4d.download]]
54+
sha256 = "084ccbc40fde07496c401ee2bc389b9cd1d60b1ac3b7ccbfde05479ea91ca707"
55+
url = "https://huggingface.co/FluxML/resnext50_32x4d/resolve/150c52c9646fe697030d38ab2be767564fb4f28c/resnext50_32x4d.tar.gz"
56+
57+
[squeezenet]
58+
git-tree-sha1 = "e2eeee109fda46470d657b13669cca09d5ef2f8c"
59+
60+
[[squeezenet.download]]
61+
sha256 = "aebfa06f44767e5ff728b7b67b2d01352b4618bd5d305c26e603aabcd5ba593d"
62+
url = "https://huggingface.co/FluxML/squeezenet/resolve/01ef4221df5260bd992c669ab587eed74df0c39f/squeezenet.tar.gz"
63+
164
[vgg11]
265
git-tree-sha1 = "78ffe7d74c475cc28175f9e23a545ce2f17b1520"
366
lazy = true
@@ -29,3 +92,17 @@ lazy = true
2992
[[vgg19.download]]
3093
sha256 = "5fe26391572b9f6ac84eaa0541d27e959f673f82e6515026cdcd3262cbd93ceb"
3194
url = "https://huggingface.co/FluxML/vgg19/resolve/88e9056f60b054eccdc190a2eeb23731d5c693b6/vgg19.tar.gz"
95+
96+
[wideresnet101]
97+
git-tree-sha1 = "b881a9469fb230faff414ce9f983bc113061ab1c"
98+
99+
[[wideresnet101.download]]
100+
sha256 = "defa61fd80a988bb07bb9db00c692d8d0a30d95e6276add1413fb7f1f3aa2607"
101+
url = "https://huggingface.co/FluxML/wideresnet101/resolve/ad4df1016bb5eba4c10e2d37b049ca5d2a455670/wideresnet101.tar.gz"
102+
103+
[wideresnet50]
104+
git-tree-sha1 = "bbc6bc632e743c992784b5121dcb0f6082c66b1f"
105+
106+
[[wideresnet50.download]]
107+
sha256 = "7596a67b7aba762c2bfce8367055da79a8a3c117bd79ce11124c9f1f5a96c4e3"
108+
url = "https://huggingface.co/FluxML/wideresnet50/resolve/5eca9979f5d9438a684b1e6a5b227f9d5611965a/wideresnet50.tar.gz"

src/convnets/inception/inceptionv3.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ See also [`inceptionv3`](#).
174174
struct Inceptionv3
175175
layers::Any
176176
end
177+
@functor Inceptionv3
177178

178179
function Inceptionv3(; pretrain::Bool = false, inchannels::Integer = 3,
179180
nclasses::Integer = 1000)
@@ -184,8 +185,6 @@ function Inceptionv3(; pretrain::Bool = false, inchannels::Integer = 3,
184185
return Inceptionv3(layers)
185186
end
186187

187-
@functor Inceptionv3
188-
189188
(m::Inceptionv3)(x) = m.layers(x)
190189

191190
backbone(m::Inceptionv3) = m.layers[1]

src/convnets/resnets/resnet.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@ Creates a ResNet model with the specified depth.
1111
- `inchannels`: The number of input channels.
1212
- `nclasses`: the number of output classes
1313
14-
!!! warning
15-
16-
`ResNet` does not currently support pretrained weights.
17-
1814
Advanced users who want more configuration options will be better served by using [`resnet`](#).
1915
"""
2016
struct ResNet
@@ -52,10 +48,6 @@ The number of channels in outer 1x1 convolutions is the same.
5248
- `inchannels`: The number of input channels.
5349
- `nclasses`: the number of output classes
5450
55-
!!! warning
56-
57-
`WideResNet` does not currently support pretrained weights.
58-
5951
Advanced users who want more configuration options will be better served by using [`resnet`](#).
6052
"""
6153
struct WideResNet

src/convnets/resnets/resnext.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
"""
2-
ResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32,
3-
base_width = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
2+
ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
3+
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
44
55
Creates a ResNeXt model with the specified depth, cardinality, and base width.
66
((reference)[https://arxiv.org/abs/1611.05431])
77
88
# Arguments
99
1010
- `depth`: one of `[18, 34, 50, 101, 152]`. The depth of the ResNet model.
11-
- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
11+
- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet.
12+
Supported configurations are:
13+
- depth 50, cardinality of 32 and base width of 4.
14+
- depth 101, cardinality of 32 and base width of 8.
15+
- depth 101, cardinality of 64 and base width of 4.
1216
- `cardinality`: the number of groups to be used in the 3x3 convolution in each block.
1317
- `base_width`: the number of feature maps in each group.
1418
- `inchannels`: the number of input channels.
1519
- `nclasses`: the number of output classes
1620
17-
!!! warning
18-
19-
`ResNeXt` does not currently support pretrained weights.
20-
2121
Advanced users who want more configuration options will be better served by using [`resnet`](#).
2222
"""
2323
struct ResNeXt
@@ -27,8 +27,8 @@ end
2727

2828
(m::ResNeXt)(x) = m.layers(x)
2929

30-
function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32,
31-
base_width = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
30+
function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
31+
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
3232
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
3333
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
3434
if pretrain

src/convnets/resnets/seresnet.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer =
3030
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses,
3131
attn_fn = squeeze_excite)
3232
if pretrain
33-
loadpretrain!(layers, string("SEResNet", depth))
33+
loadpretrain!(layers, string("seresnet", depth))
3434
end
3535
return SEResNet(layers)
3636
end
@@ -39,8 +39,8 @@ backbone(m::SEResNet) = m.layers[1]
3939
classifier(m::SEResNet) = m.layers[2]
4040

4141
"""
42-
SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32, base_width = 4,
43-
inchannels::Integer = 3, nclasses::Integer = 1000)
42+
SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
43+
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
4444
4545
Creates a SEResNeXt model with the specified depth, cardinality, and base width.
4646
((reference)[https://arxiv.org/pdf/1709.01507.pdf])
@@ -67,13 +67,13 @@ end
6767

6868
(m::SEResNeXt)(x) = m.layers(x)
6969

70-
function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32, base_width = 4,
71-
inchannels::Integer = 3, nclasses::Integer = 1000)
70+
function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
71+
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
7272
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
7373
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
7474
attn_fn = squeeze_excite)
7575
if pretrain
76-
loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width))
76+
loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width))
7777
end
7878
return SEResNeXt(layers)
7979
end

src/convnets/squeezenet.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,6 @@ Create a SqueezeNet
6262
- `inchannels`: number of input channels.
6363
- `nclasses`: the number of output classes.
6464
65-
!!! warning
66-
67-
`SqueezeNet` does not currently support pretrained weights.
68-
6965
See also [`squeezenet`](#).
7066
"""
7167
struct SqueezeNet

test/runtests.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,7 @@ const PRETRAINED_MODELS = [
1818
(WideResNet, 101),
1919
(ResNeXt, 50, 32, 4),
2020
(ResNeXt, 101, 64, 4),
21-
(ResNeXt, 101, 32, 8),
22-
(DenseNet, 121),
23-
(DenseNet, 161),
24-
(DenseNet, 169),
25-
(DenseNet, 201),
21+
(ResNeXt, 101, 32, 8)
2622
]
2723

2824
function _gc()
@@ -33,7 +29,6 @@ end
3329
function gradtest(model, input)
3430
y, pb = Zygote.pullback(() -> model(input), Flux.params(model))
3531
gs = pb(ones(Float32, size(y)))
36-
3732
# if we make it to here with no error, success!
3833
return true
3934
end
@@ -50,13 +45,12 @@ const TEST_IMG = imresize(Images.load(TEST_PATH), (224, 224))
5045
# CHW -> WHC
5146
const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3, 2, 1)) |> normalize_imagenet
5247

53-
# image net labels
48+
# ImageNet labels
5449
const TEST_LBLS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))
5550

5651
function acctest(model)
5752
ypred = model(TEST_X) |> vec
5853
top5 = TEST_LBLS[sortperm(ypred; rev = true)]
59-
6054
return "acoustic guitar" top5
6155
end
6256

0 commit comments

Comments
 (0)