Skip to content

Commit aaf5e32

Browse files
authored
Merge branch 'master' into resnet-plus
2 parents a4d3f12 + 9edff63 commit aaf5e32

File tree

15 files changed

+219
-56
lines changed

15 files changed

+219
-56
lines changed

Artifacts.toml

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,71 @@
1-
[densenet121]
2-
git-tree-sha1 = "ffc7f7ed1e7f67baca4b76f6c100e0d5042ff063"
1+
[vgg11]
2+
git-tree-sha1 = "78ffe7d74c475cc28175f9e23a545ce2f17b1520"
33
lazy = true
44

5-
[[densenet121.download]]
6-
sha256 = "3fd10f0be70cf072fa7f1358f1fbbe01138440dbcaec1b7c8e007084382c1557"
7-
url = "https://github.com/FluxML/MetalheadWeights/releases/download/v0.1.1/densenet121-0.1.1.tar.gz"
5+
[[vgg11.download]]
6+
sha256 = "9703268c19ca2ae34036ca3588664a96dc0ca8d9d6458db78657299c6879880c"
7+
url = "https://huggingface.co/FluxML/vgg11/resolve/275b202a8a4d10b59eef74285921d278b51fdbdb/vgg11.tar.gz"
88

9-
[googlenet]
10-
git-tree-sha1 = "56cc81845fcca30508fe81da18c7ba0d96d72cdd"
9+
[vgg13]
10+
git-tree-sha1 = "ed006dd09cc24342d4dcd9e2cfaa8c84f063c27a"
1111
lazy = true
1212

13-
[[googlenet.download]]
14-
sha256 = "8ab8d60cc26e81451473badc9dc749b5ffc170a11bc00fb4b203da34fbfdc996"
15-
url = "https://github.com/FluxML/MetalheadWeights/releases/download/v0.1.1/googlenet-0.1.1.tar.gz"
13+
[[vgg13.download]]
14+
sha256 = "ef27949024f5716f7656b3318b06964d76587851f15d9a9127c2b55e5faee288"
15+
url = "https://huggingface.co/FluxML/vgg13/resolve/9593b269ee2c24ce5924d3667496a0d7458a6cb4/vgg13.tar.gz"
16+
17+
[vgg16]
18+
git-tree-sha1 = "759df92ca502324d8624e1c5a940db227908fb9e"
19+
lazy = true
20+
21+
[[vgg16.download]]
22+
sha256 = "f9bad8d9d2c79bc4ebab840f2faded2a0c26c6b2a84f979525964eebcd1886ab"
23+
url = "https://huggingface.co/FluxML/vgg16/resolve/57fdb74b1640815f17eae1a28ae67f0fc1c603db/vgg16.tar.gz"
24+
25+
[vgg19]
26+
git-tree-sha1 = "67f5e867f297086cc911c2cb7985bec8ac1ab23d"
27+
lazy = true
28+
29+
[[vgg19.download]]
30+
sha256 = "5fe26391572b9f6ac84eaa0541d27e959f673f82e6515026cdcd3262cbd93ceb"
31+
url = "https://huggingface.co/FluxML/vgg19/resolve/88e9056f60b054eccdc190a2eeb23731d5c693b6/vgg19.tar.gz"
32+
33+
[resnet18]
34+
git-tree-sha1 = "7b555ed2708e551bfdbcb7e71b25001f4b3731c6"
35+
lazy = true
36+
37+
[[resnet18.download]]
38+
sha256 = "d5782fd873a3072df251c7a4b3cf16efca8ee1da1180ff815bc107833f84bb26"
39+
url = "https://huggingface.co/FluxML/resnet18/resolve/ef9c74047fda4a4a503b1f72553ec05acc90929f/resnet18.tar.gz"
40+
41+
[resnet34]
42+
git-tree-sha1 = "e6e79666cd0fc81cd828508314e6c7f66df8d43d"
43+
lazy = true
44+
45+
[[resnet34.download]]
46+
sha256 = "a8dec13609a86f7a2adac6a44b3af912a863bc2d7319120066c5fdaa04c3f395"
47+
url = "https://huggingface.co/FluxML/resnet34/resolve/42061ddb463902885eea4fcc85275462a5445987/resnet34.tar.gz"
1648

1749
[resnet50]
18-
git-tree-sha1 = "ea3effeaf1ea3969ed5c609f5db5cd0e456ce799"
50+
git-tree-sha1 = "5c442ffd6c51a70c3bc36d849fca86beced446d4"
1951
lazy = true
2052

2153
[[resnet50.download]]
22-
sha256 = "17760ae50e3d59ed7d74c3dfcdb9f0eeaccec1e2ccd095663955c9fed4f318a8"
23-
url = "https://github.com/FluxML/MetalheadWeights/releases/download/v0.1.1/resnet50-0.1.1.tar.gz"
54+
sha256 = "5325920ec91c2a4499ad7e659961f9eaac2b1a3a2905ca6410eaa593ecd35503"
55+
url = "https://huggingface.co/FluxML/resnet50/resolve/10e601719e1cd5b0cab87ce7fd1e8f69a07ce042/resnet50.tar.gz"
2456

25-
[squeezenet]
26-
git-tree-sha1 = "e0e53eb402efe4693417db8cbcc31519e74c8c74"
57+
[resnet101]
58+
git-tree-sha1 = "694a8563ec20fb826334dd663d532b10bb2b3c97"
2759
lazy = true
2860

29-
[[squeezenet.download]]
30-
sha256 = "a3e60f2731296cdf0f32b79badd227eb8dad88a9bee8c828dbe60382869c50f0"
31-
url = "https://github.com/FluxML/MetalheadWeights/releases/download/v0.1.1/squeezenet-0.1.1.tar.gz"
61+
[[resnet101.download]]
62+
sha256 = "f4d737ce640957c30f76bfa642fc9da23e6852d81474d58a2338c1148e55bff0"
63+
url = "https://huggingface.co/FluxML/resnet101/resolve/ea37819163cc3f4a41989a6239ce505e483b112d/resnet101.tar.gz"
3264

33-
[vgg19]
34-
git-tree-sha1 = "072056ec63bf7308cf89885e91852666e191e80a"
65+
[resnet152]
66+
git-tree-sha1 = "55eb883248a276d710d75ecaecfbd2427e50cc0a"
3567
lazy = true
3668

37-
[[vgg19.download]]
38-
sha256 = "0fa000609965604b9d249e84190c30d067d443d73e6c8e340ef09bd013d0bc90"
39-
url = "https://github.com/FluxML/MetalheadWeights/releases/download/v0.1.1/vgg19-0.1.1.tar.gz"
69+
[[resnet152.download]]
70+
sha256 = "57be335e6828d1965c9d11f933d2d41f51e5e534f9bfdbde01c6144fa8862a4d"
71+
url = "https://huggingface.co/FluxML/resnet152/resolve/ba28814d5746643387b5c0e1d2269104e5e9bc8d/resnet152.tar.gz"

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ NNlib = "0.7.34, 0.8"
2323
julia = "1.6"
2424

2525
[extras]
26+
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
2627
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2728

2829
[publish]
@@ -31,4 +32,4 @@ theme = "_flux-theme"
3132
title = "Metalhead.jl"
3233

3334
[targets]
34-
test = ["Test"]
35+
test = ["Images", "Test"]

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
| Model Name | Function | Pre-trained? |
1818
|:-------------------------------------------------|:------------------------------------------------------------------------------------------|:------------:|
19-
| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.VGG.html) | N |
20-
| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNet.html) | N |
19+
| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.VGG.html) | Y (w/o BN) |
20+
| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNet.html) | Y |
2121
| [GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.GoogLeNet.html) | N |
2222
| [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.Inceptionv3.html) | N |
2323
| [Inception-v4](https://arxiv.org/abs/1602.07261) | [`Inceptionv4`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.Inceptionv4.html) | N |
@@ -35,6 +35,8 @@
3535
| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ConvNeXt.html) | N |
3636
| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ConvMixer.html) | N |
3737

38+
To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhead.jl/dev/docs/dev-guide/contributing.html).
39+
3840
## Getting Started
3941

4042
You can find the Metalhead.jl getting started guide [here](https://fluxml.ai/Metalhead.jl/dev/docs/tutorials/quickstart.html).

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
[deps]
2+
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
23
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
4+
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
35
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
46
Publish = "f065f642-d108-4f50-8aa5-6749150a895a"

docs/dev-guide/contributing.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Contributing to Metalhead.jl
2+
3+
We welcome contributions from anyone to Metalhead.jl! Thank you for taking the time to make our ecosystem better.
4+
5+
You can contribute by fixing bugs, adding new models, or adding pre-trained weights. If you aren't ready to write some code, but you think you found a bug or have a feature request, please [post an issue](https://github.com/FluxML/Metalhead.jl/issues/new/choose).
6+
7+
Before continuing, make sure you read the [FluxML contributing guide](https://github.com/FluxML/Flux.jl/blob/master/CONTRIBUTING.md) for general guidelines and tips.
8+
9+
## Fixing bugs
10+
11+
To fix a bug in Metalhead.jl, you can [open a PR](https://github.com/FluxML/Metalhead.jl/pulls). It would be helpful to file an issue first so that we can confirm the bug.
12+
13+
## Adding models
14+
15+
To add a new model architecture to Metalhead.jl, you can [open a PR](https://github.com/FluxML/Metalhead.jl/pulls). Keep in mind a few guiding principles for how this package is designed:
16+
17+
- reuse layers from Flux as much as possible (e.g. use `Parallel` before defining a `Bottleneck` struct)
18+
- adhere as closely as possible to a reference such as a published paper (i.e. the structure of your model should follow intuitively from the paper)
19+
- use generic functional builders (e.g. [`resnet`](#) is the core function that builds "ResNet-like" models)
20+
- use multiple dispatch to add convenience constructors that wrap your functional builder
21+
22+
When in doubt, just open a PR! We are more than happy to review your code and help it align with the rest of the library. After adding a model, you might consider adding some pre-trained weights (see below).
23+
24+
## Adding pre-trained weights
25+
26+
To add pre-trained weights for an existing model or new model, you can [open a PR](https://github.com/FluxML/Metalhead.jl/pulls). Below, we describe the steps you should follow to get there.
27+
28+
All Metalhead.jl model artifacts are hosted using HuggingFace. You can find the FluxML account [here](https://huggingface.co/FluxML). This [documentation from HuggingFace](https://huggingface.co/docs/hub/models) will provide you with an introduction to their Model Hub. In short, the Model Hub is a collection of Git repositories, similar to Julia packages on GitHub. This means you can [make a pull request to our HuggingFace repositories](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to upload updated weight artifacts just like you would make a PR on GitHub to upload code.
29+
30+
1. Train your model or port the weights from another framework.
31+
2. Save the model using [BSON.jl](https://github.com/JuliaIO/BSON.jl) with `BSON.@save "modelname.bson" model`. It is important that your model is saved under the key `model`.
32+
3. Compress the saved model as a tarball using `tar -cvzf modelname.tar.gz modelname.bson`.
33+
4. Obtain the SHAs (see the [Pkg docs](https://pkgdocs.julialang.org/v1/artifacts/#Basic-Usage)). Edit the `Artifacts.toml` file in the Metalhead.jl repository and add an entry for your model. You can leave the URL empty for now.
34+
5. Open a PR on Metalhead.jl. Be sure to ping a maintainer (e.g. `@darsnack`) to let us know that you are adding a pre-trained weight. We will create a model repository on HuggingFace if it does not already exist.
35+
6. Open a PR to the [corresponding HuggingFace repo](https://huggingface.co/FluxML). Do this by going to the "Community" tab in the HuggingFace repository. PRs and discussions are shown as the same thing in the HuggingFace web app. You can use your local Git program to clone the repo and make PRs if you wish. Check out the [guide on PRs to HuggingFace](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) for more information.
36+
7. Copy the download URL for the model file that you added to HuggingFace. Make sure to grab the URL for a specific commit and not for the `main` branch.
37+
8. Update your Metalhead.jl PR by adding the URL to the Artifacts.toml.
38+
9. If the tests pass for your weights, we will merge your PR! Your model should pass the `acctest` function in the Metalhead.jl test suite. If your model already exists in the repo, then these tests are already in place, and you can add your model configuration to the `PRETRAINED_MODELS` list in the `runtests.jl` file. Please refer to the ResNet tests as an example.
39+
40+
If you want to fix existing weights, then you can follow the same set of steps.

docs/tutorials/quickstart.md

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,44 @@
55
using Flux, Metalhead
66
```
77

8-
Using a model from Metalhead is as simple as selecting a model from the table of [available models](#). For example, below we use the ResNet-18 model.
8+
Using a model from Metalhead is as simple as selecting a model from the table of [available models](#). For example, below we use the pre-trained ResNet-18 model.
99
{cell=quickstart}
1010
```julia
1111
using Flux, Metalhead
1212

13-
model = ResNet(18)
13+
model = ResNet(18; pretrain = true)
1414
```
1515

16-
Now, we can use this model with Flux like any other model. Below, we train it on some randomly generated data.
16+
Now, we can use this model with Flux like any other model.
17+
18+
First, let's check the model's prediction on a test image.
19+
{cell=quickstart}
20+
```julia
21+
using Images
22+
23+
# test image
24+
img = Images.load(download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg"))
25+
```
26+
We'll use the popular [DataAugmentation.jl](https://github.com/lorenzoh/DataAugmentation.jl) library to crop our input image, convert it to a plain array, and normalize the pixels.
27+
{cell=quickstart}
28+
```julia
29+
using DataAugmentation
30+
31+
DATA_MEAN = (0.485, 0.456, 0.406)
32+
DATA_STD = (0.229, 0.224, 0.225)
33+
34+
augmentations = CenterCrop((224, 224)) |>
35+
ImageToTensor() |>
36+
Normalize(DATA_MEAN, DATA_STD)
37+
data = apply(augmentations, Image(img)) |> itemdata
38+
39+
# ImageNet labels
40+
labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))
41+
42+
Flux.onecold(model(data), labels)
43+
```
44+
45+
Below, we train the model on some randomly generated data.
1746

1847
```julia
1948
using Flux: onehotbatch

src/Metalhead.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using BSON
77
using Artifacts, LazyArtifacts
88
using Statistics
99
using MLUtils
10+
using Random
1011

1112
import Functors
1213

src/convnets/inception.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,8 +579,9 @@ Creates an Xception model.
579579
580580
`Xception` does not currently support pretrained weights.
581581
"""
582-
function Xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
582+
function Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
583583
layers = xception(; inchannels, dropout, nclasses)
584+
pretrain && loadpretrain!(layers, "xception")
584585
return Xception(layers)
585586
end
586587

src/convnets/resnet.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,5 @@ function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride
180180
num_features = 512 * expansion
181181
classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten,
182182
Dense(num_features, num_classes))
183-
184183
return Chain(Chain(stem, stage_blocks), classifier)
185184
end

src/convnets/vgg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses
171171
fcsize = 4096,
172172
dropout = 0.5)
173173
if pretrain && !batchnorm
174-
loadpretrain!(model, string("VGG", depth))
174+
loadpretrain!(model, string("vgg", depth))
175175
elseif pretrain
176-
loadpretrain!(model, "VGG$(depth)-BN)")
176+
loadpretrain!(model, "vgg$(depth)-bn)")
177177
end
178178
return model
179179
end

0 commit comments

Comments
 (0)