Skip to content

Commit dd0bb6d

Browse files
committed
Add remaining weights
1 parent e1fd820 commit dd0bb6d

File tree

3 files changed

+78
-4
lines changed

3 files changed

+78
-4
lines changed

Artifacts.toml

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,67 @@ lazy = true
55
[[vgg11.download]]
66
sha256 = "9703268c19ca2ae34036ca3588664a96dc0ca8d9d6458db78657299c6879880c"
77
url = "https://huggingface.co/FluxML/vgg11/resolve/275b202a8a4d10b59eef74285921d278b51fdbdb/vgg11.tar.gz"
8+
9+
[vgg13]
10+
git-tree-sha1 = "ed006dd09cc24342d4dcd9e2cfaa8c84f063c27a"
11+
lazy = true
12+
13+
[[vgg13.download]]
14+
sha256 = "ef27949024f5716f7656b3318b06964d76587851f15d9a9127c2b55e5faee288"
15+
url = "https://huggingface.co/FluxML/vgg13/resolve/9593b269ee2c24ce5924d3667496a0d7458a6cb4/vgg13.tar.gz"
16+
17+
[vgg16]
18+
git-tree-sha1 = "759df92ca502324d8624e1c5a940db227908fb9e"
19+
lazy = true
20+
21+
[[vgg16.download]]
22+
sha256 = "f9bad8d9d2c79bc4ebab840f2faded2a0c26c6b2a84f979525964eebcd1886ab"
23+
url = "https://huggingface.co/FluxML/vgg16/resolve/57fdb74b1640815f17eae1a28ae67f0fc1c603db/vgg16.tar.gz"
24+
25+
[vgg19]
26+
git-tree-sha1 = "67f5e867f297086cc911c2cb7985bec8ac1ab23d"
27+
lazy = true
28+
29+
[[vgg19.download]]
30+
sha256 = "5fe26391572b9f6ac84eaa0541d27e959f673f82e6515026cdcd3262cbd93ceb"
31+
url = "https://huggingface.co/FluxML/vgg19/resolve/88e9056f60b054eccdc190a2eeb23731d5c693b6/vgg19.tar.gz"
32+
33+
[resnet18]
34+
git-tree-sha1 = "7b555ed2708e551bfdbcb7e71b25001f4b3731c6"
35+
lazy = true
36+
37+
[[resnet18.download]]
38+
sha256 = "d5782fd873a3072df251c7a4b3cf16efca8ee1da1180ff815bc107833f84bb26"
39+
url = "https://huggingface.co/FluxML/resnet18/resolve/ef9c74047fda4a4a503b1f72553ec05acc90929f/resnet18.tar.gz"
40+
41+
[resnet34]
42+
git-tree-sha1 = "e6e79666cd0fc81cd828508314e6c7f66df8d43d"
43+
lazy = true
44+
45+
[[resnet34.download]]
46+
sha256 = "a8dec13609a86f7a2adac6a44b3af912a863bc2d7319120066c5fdaa04c3f395"
47+
url = "https://huggingface.co/FluxML/resnet34/resolve/42061ddb463902885eea4fcc85275462a5445987/resnet34.tar.gz"
48+
49+
[resnet50]
50+
git-tree-sha1 = "5c442ffd6c51a70c3bc36d849fca86beced446d4"
51+
lazy = true
52+
53+
[[resnet50.download]]
54+
sha256 = "5325920ec91c2a4499ad7e659961f9eaac2b1a3a2905ca6410eaa593ecd35503"
55+
url = "https://huggingface.co/FluxML/resnet50/resolve/10e601719e1cd5b0cab87ce7fd1e8f69a07ce042/resnet50.tar.gz"
56+
57+
[resnet101]
58+
git-tree-sha1 = "694a8563ec20fb826334dd663d532b10bb2b3c97"
59+
lazy = true
60+
61+
[[resnet101.download]]
62+
sha256 = "f4d737ce640957c30f76bfa642fc9da23e6852d81474d58a2338c1148e55bff0"
63+
url = "https://huggingface.co/FluxML/resnet101/resolve/ea37819163cc3f4a41989a6239ce505e483b112d/resnet101.tar.gz"
64+
65+
[resnet152]
66+
git-tree-sha1 = "55eb883248a276d710d75ecaecfbd2427e50cc0a"
67+
lazy = true
68+
69+
[[resnet152.download]]
70+
sha256 = "57be335e6828d1965c9d11f933d2d41f51e5e534f9bfdbde01c6144fa8862a4d"
71+
url = "https://huggingface.co/FluxML/resnet152/resolve/ba28814d5746643387b5c0e1d2269104e5e9bc8d/resnet152.tar.gz"

src/convnets/resnet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,6 @@ function ResNet(depth::Integer = 50; pretrain = false, nclasses = 1000)
258258
@assert depth in keys(resnet_config) "`depth` must be one of $(sort(collect(keys(resnet_config))))"
259259
config, block = resnet_config[depth]
260260
model = ResNet(config...; block = block, nclasses = nclasses)
261-
pretrain && loadpretrain!(model, string("ResNet", depth))
261+
pretrain && loadpretrain!(model, string("resnet", depth))
262262
return model
263263
end

test/runtests.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,17 @@ using Flux
33
using Flux: Zygote
44
using Images
55

6-
const PRETRAINED_MODELS = [(VGG, 11, false)]
6+
const PRETRAINED_MODELS = [
7+
(VGG, 11, false),
8+
(VGG, 13, false),
9+
(VGG, 16, false),
10+
(VGG, 19, false),
11+
(ResNet, 18),
12+
(ResNet, 34),
13+
(ResNet, 50),
14+
(ResNet, 101),
15+
(ResNet, 152),
16+
]
717

818
function gradtest(model, input)
919
y, pb = Zygote.pullback(() -> model(input), Flux.params(model))
@@ -29,9 +39,9 @@ const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3,2,
2939
const TEST_LBLS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))
3040

3141
function acctest(model)
32-
@show ypred = Flux.onecold(model(TEST_X), TEST_LBLS)
42+
ypred = Flux.onecold(model(TEST_X), TEST_LBLS)
3343

34-
return ypred == "acoustic guitar"
44+
return only(ypred) == "acoustic guitar"
3545
end
3646

3747
x_224 = rand(Float32, 224, 224, 3, 1)

0 commit comments

Comments
 (0)