Skip to content

Commit 92f2d0d

Browse files
committed
Merge branch 'master' into cutmix. Fixup a few issues.
2 parents 670c61b + 6e9d617 commit 92f2d0d

File tree

145 files changed

+13651
-4255
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

145 files changed

+13651
-4255
lines changed

.github/workflows/tests.yml

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
name: Python tests
2+
3+
on:
4+
push:
5+
branches: [ master ]
6+
pull_request:
7+
branches: [ master ]
8+
9+
env:
10+
OMP_NUM_THREADS: 2
11+
MKL_NUM_THREADS: 2
12+
13+
jobs:
14+
test:
15+
name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }}
16+
strategy:
17+
matrix:
18+
os: [ubuntu-latest, macOS-latest]
19+
python: ['3.8']
20+
torch: ['1.5.0']
21+
torchvision: ['0.6.0']
22+
runs-on: ${{ matrix.os }}
23+
24+
steps:
25+
- uses: actions/checkout@v2
26+
- name: Set up Python ${{ matrix.python }}
27+
uses: actions/setup-python@v1
28+
with:
29+
python-version: ${{ matrix.python }}
30+
- name: Install testing dependencies
31+
run: |
32+
python -m pip install --upgrade pip
33+
pip install pytest pytest-timeout
34+
- name: Install torch on mac
35+
if: startsWith(matrix.os, 'macOS')
36+
run: pip install torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
37+
- name: Install torch on ubuntu
38+
if: startsWith(matrix.os, 'ubuntu')
39+
run: pip install torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
40+
- name: Install requirements
41+
run: |
42+
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
43+
pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11
44+
- name: Run tests
45+
run: |
46+
pytest -vv --durations=0 ./tests

MANIFEST.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
include timm/models/pruned/*.txt
2+

README.md

Lines changed: 135 additions & 376 deletions
Large diffs are not rendered by default.

avg_checkpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
EMA (exponential moving average) of the model weights or performing SWA (stochastic
1010
weight averaging), but post-training.
1111
12-
Hacked together by Ross Wightman (https://github.com/rwightman)
12+
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
1313
"""
1414
import torch
1515
import argparse

clean_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
and outputs a CPU tensor checkpoint with only the `state_dict` along with SHA256
66
calculation for model zoo compatibility.
77
8-
Hacked together by Ross Wightman (https://github.com/rwightman)
8+
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
99
"""
1010
import torch
1111
import argparse

docs/archived_changes.md

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Archived Changes
2+
3+
### Feb 29, 2020
4+
* New MobileNet-V3 Large weights trained from stratch with this code to 75.77% top-1
5+
* IMPORTANT CHANGE - default weight init changed for all MobilenetV3 / EfficientNet / related models
6+
* overall results similar to a bit better training from scratch on a few smaller models tried
7+
* performance early in training seems consistently improved but less difference by end
8+
* set `fix_group_fanout=False` in `_init_weight_goog` fn if you need to reproducte past behaviour
9+
* Experimental LR noise feature added applies a random perturbation to LR each epoch in specified range of training
10+
11+
### Feb 18, 2020
12+
* Big refactor of model layers and addition of several attention mechanisms. Several additions motivated by 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268):
13+
* Move layer/module impl into `layers` subfolder/module of `models` and organize in a more granular fashion
14+
* ResNet downsample paths now properly support dilation (output stride != 32) for avg_pool ('D' variant) and 3x3 (SENets) networks
15+
* Add Selective Kernel Nets on top of ResNet base, pretrained weights
16+
* skresnet18 - 73% top-1
17+
* skresnet34 - 76.9% top-1
18+
* skresnext50_32x4d (equiv to SKNet50) - 80.2% top-1
19+
* ECA and CECA (circular padding) attention layer contributed by [Chris Ha](https://github.com/VRandme)
20+
* CBAM attention experiment (not the best results so far, may remove)
21+
* Attention factory to allow dynamically selecting one of SE, ECA, CBAM in the `.se` position for all ResNets
22+
* Add DropBlock and DropPath (formerly DropConnect for EfficientNet/MobileNetv3) support to all ResNet variants
23+
* Full dataset results updated that incl NoisyStudent weights and 2 of the 3 SK weights
24+
25+
### Feb 12, 2020
26+
* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
27+
28+
### Feb 6, 2020
29+
* Add RandAugment trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
30+
31+
### Feb 1/2, 2020
32+
* Port new EfficientNet-B8 (RandAugment) weights, these are different than the B8 AdvProp, different input normalization.
33+
* Update results csv files on all models for ImageNet validation and three other test sets
34+
* Push PyPi package update
35+
36+
### Jan 31, 2020
37+
* Update ResNet50 weights with a new 79.038 result from further JSD / AugMix experiments. Full command line for reproduction in training section below.
38+
39+
### Jan 11/12, 2020
40+
* Master may be a bit unstable wrt to training, these changes have been tested but not all combos
41+
* Implementations of AugMix added to existing RA and AA. Including numerous supporting pieces like JSD loss (Jensen-Shannon divergence + CE), and AugMixDataset
42+
* SplitBatchNorm adaptation layer added for implementing Auxiliary BN as per AdvProp paper
43+
* ResNet-50 AugMix trained model w/ 79% top-1 added
44+
* `seresnext26tn_32x4d` - 77.99 top-1, 93.75 top-5 added to tiered experiment, higher img/s than 't' and 'd'
45+
46+
### Jan 3, 2020
47+
* Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section)
48+
* Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs.
49+
50+
### Dec 30, 2019
51+
* Merge [Dushyant Mehta's](https://github.com/mehtadushy) PR for SelecSLS (Selective Short and Long Range Skip Connections) networks. Good GPU memory consumption and throughput. Original: https://github.com/mehtadushy/SelecSLS-Pytorch
52+
53+
### Dec 28, 2019
54+
* Add new model weights and training hparams (see Training Hparams section)
55+
* `efficientnet_b3` - 81.5 top-1, 95.7 top-5 at default res/crop, 81.9, 95.8 at 320x320 1.0 crop-pct
56+
* trained with RandAugment, ended up with an interesting but less than perfect result (see training section)
57+
* `seresnext26d_32x4d`- 77.6 top-1, 93.6 top-5
58+
* deep stem (32, 32, 64), avgpool downsample
59+
* stem/dowsample from bag-of-tricks paper
60+
* `seresnext26t_32x4d`- 78.0 top-1, 93.7 top-5
61+
* deep tiered stem (24, 48, 64), avgpool downsample (a modified 'D' variant)
62+
* stem sizing mods from Jeremy Howard and fastai devs discussing ResNet architecture experiments
63+
64+
### Dec 23, 2019
65+
* Add RandAugment trained MixNet-XL weights with 80.48 top-1.
66+
* `--dist-bn` argument added to train.py, will distribute BN stats between nodes after each train epoch, before eval
67+
68+
### Dec 4, 2019
69+
* Added weights from the first training from scratch of an EfficientNet (B2) with my new RandAugment implementation. Much better than my previous B2 and very close to the official AdvProp ones (80.4 top-1, 95.08 top-5).
70+
71+
### Nov 29, 2019
72+
* Brought EfficientNet and MobileNetV3 up to date with my https://github.com/rwightman/gen-efficientnet-pytorch code. Torchscript and ONNX export compat excluded.
73+
* AdvProp weights added
74+
* Official TF MobileNetv3 weights added
75+
* EfficientNet and MobileNetV3 hook based 'feature extraction' classes added. Will serve as basis for using models as backbones in obj detection/segmentation tasks. Lots more to be done here...
76+
* HRNet classification models and weights added from https://github.com/HRNet/HRNet-Image-Classification
77+
* Consistency in global pooling, `reset_classifer`, and `forward_features` across models
78+
* `forward_features` always returns unpooled feature maps now
79+
* Reasonable chance I broke something... let me know
80+
81+
### Nov 22, 2019
82+
* Add ImageNet training RandAugment implementation alongside AutoAugment. PyTorch Transform compatible format, using PIL. Currently training two EfficientNet models from scratch with promising results... will update.
83+
* `drop-connect` cmd line arg finally added to `train.py`, no need to hack model fns. Works for efficientnet/mobilenetv3 based models, ignored otherwise.

docs/changes.md

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Recent Changes
2+
3+
### Aug 5, 2020
4+
Universal feature extraction, new models, new weights, new test sets.
5+
6+
* All models support the `features_only=True` argument for `create_model` call to return a network that extracts features from the deepest layer at each stride.
7+
* New models
8+
* CSPResNet, CSPResNeXt, CSPDarkNet, DarkNet
9+
* ReXNet
10+
* (Modified Aligned) Xception41/65/71 (a proper port of TF models)
11+
* New trained weights
12+
* SEResNet50 - 80.3 top-1
13+
* CSPDarkNet53 - 80.1 top-1
14+
* CSPResNeXt50 - 80.0 top-1
15+
* DPN68b - 79.2 top-1
16+
* EfficientNet-Lite0 (non-TF ver) - 75.5 (submitted by [@hal-314](https://github.com/hal-314))
17+
* Add 'real' labels for ImageNet and ImageNet-Renditions test set, see [`results/README.md`](results/README.md)
18+
* Test set ranking/top-n diff script by [@KushajveerSingh](https://github.com/KushajveerSingh)
19+
* Train script and loader/transform tweaks to punch through more aug arguments
20+
* README and documentation overhaul. See initial (WIP) documentation at https://rwightman.github.io/pytorch-image-models/
21+
* adamp and sgdp optimizers added by [@hellbell](https://github.com/hellbell)
22+
23+
### June 11, 2020
24+
Bunch of changes:
25+
26+
* DenseNet models updated with memory efficient addition from torchvision (fixed a bug), blur pooling and deep stem additions
27+
* VoVNet V1 and V2 models added, 39 V2 variant (ese_vovnet_39b) trained to 79.3 top-1
28+
* Activation factory added along with new activations:
29+
* select act at model creation time for more flexibility in using activations compatible with scripting or tracing (ONNX export)
30+
* hard_mish (experimental) added with memory-efficient grad, along with ME hard_swish
31+
* context mgr for setting exportable/scriptable/no_jit states
32+
* Norm + Activation combo layers added with initial trial support in DenseNet and VoVNet along with impl of EvoNorm and InplaceAbn wrapper that fit the interface
33+
* Torchscript works for all but two of the model types as long as using Pytorch 1.5+, tests added for this
34+
* Some import cleanup and classifier reset changes, all models will have classifier reset to nn.Identity on reset_classifer(0) call
35+
* Prep for 0.1.28 pip release
36+
37+
### May 12, 2020
38+
* Add ResNeSt models (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955))
39+
40+
### May 3, 2020
41+
* Pruned EfficientNet B1, B2, and B3 (https://arxiv.org/abs/2002.08258) contributed by [Yonathan Aflalo](https://github.com/yoniaflalo)
42+
43+
### May 1, 2020
44+
* Merged a number of execellent contributions in the ResNet model family over the past month
45+
* BlurPool2D and resnetblur models initiated by [Chris Ha](https://github.com/VRandme), I trained resnetblur50 to 79.3.
46+
* TResNet models and SpaceToDepth, AntiAliasDownsampleLayer layers by [mrT23](https://github.com/mrT23)
47+
* ecaresnet (50d, 101d, light) models and two pruned variants using pruning as per (https://arxiv.org/abs/2002.08258) by [Yonathan Aflalo](https://github.com/yoniaflalo)
48+
* 200 pretrained models in total now with updated results csv in results folder
49+
50+
### April 5, 2020
51+
* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite
52+
* 3.5M param MobileNet-V2 100 @ 73%
53+
* 4.5M param MobileNet-V2 110d @ 75%
54+
* 6.1M param MobileNet-V2 140 @ 76.5%
55+
* 5.8M param MobileNet-V2 120d @ 77.3%
56+
57+
### March 18, 2020
58+
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
59+
* Add RandAugment trained ResNeXt-50 32x4d weights with 79.8 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)

docs/feature_extraction.md

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Feature Extraction
2+
3+
All of the models in `timm` have consistent mechanisms for obtaining various types of features from the model for tasks besides classification.
4+
5+
## Penultimate Layer Features (Pre-Classifier Features)
6+
7+
The features from the penultimate model layer can be obtained in severay ways without requiring model surgery (although feel free to do surgery). One must first decide if they want pooled or un-pooled features.
8+
9+
### Unpooled
10+
11+
There are three ways to obtain unpooled features.
12+
13+
Without modifying the network, one can call `model.forward_features(input)` on any model instead of the usual `model(input)`. This will bypass the head classifier and global pooling for networks.
14+
15+
If one wants to explicitly modify the network to return unpooled features, they can either create the model without a classifier and pooling, or remove it later. Both paths remove the parameters associated with the classifier from the network.
16+
17+
#### forward_features()
18+
```python hl_lines="3 6"
19+
import torch
20+
import timm
21+
m = timm.create_model('xception41', pretrained=True)
22+
o = m(torch.randn(2, 3, 299, 299))
23+
print(f'Original shape: {o.shape}')
24+
o = m.forward_features(torch.randn(2, 3, 299, 299))
25+
print(f'Unpooled shape: {o.shape}')
26+
```
27+
Output:
28+
```text
29+
Original shape: torch.Size([2, 1000])
30+
Unpooled shape: torch.Size([2, 2048, 10, 10])
31+
```
32+
33+
#### Create with no classifier and pooling
34+
```python hl_lines="3"
35+
import torch
36+
import timm
37+
m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
38+
o = m(torch.randn(2, 3, 224, 224))
39+
print(f'Unpooled shape: {o.shape}')
40+
```
41+
Output:
42+
```text
43+
Unpooled shape: torch.Size([2, 2048, 7, 7])
44+
```
45+
46+
#### Remove it later
47+
```python hl_lines="3 6"
48+
import torch
49+
import timm
50+
m = timm.create_model('densenet121', pretrained=True)
51+
o = m(torch.randn(2, 3, 224, 224))
52+
print(f'Original shape: {o.shape}')
53+
m.reset_classifier(0, '')
54+
o = m(torch.randn(2, 3, 224, 224))
55+
print(f'Unpooled shape: {o.shape}')
56+
```
57+
Output:
58+
```text
59+
Original shape: torch.Size([2, 1000])
60+
Unpooled shape: torch.Size([2, 1024, 7, 7])
61+
```
62+
63+
### Pooled
64+
65+
To modify the network to return pooled features, one can use `forward_features()` and pool/flatten the result themselves, or modify the network like above but keep pooling intact.
66+
67+
#### Create with no classifier
68+
```python hl_lines="3"
69+
import torch
70+
import timm
71+
m = timm.create_model('resnet50', pretrained=True, num_classes=0)
72+
o = m(torch.randn(2, 3, 224, 224))
73+
print(f'Pooled shape: {o.shape}')
74+
```
75+
Output:
76+
```text
77+
Pooled shape: torch.Size([2, 2048])
78+
```
79+
80+
#### Remove it later
81+
```python hl_lines="3 6"
82+
import torch
83+
import timm
84+
m = timm.create_model('ese_vovnet19b_dw', pretrained=True)
85+
o = m(torch.randn(2, 3, 224, 224))
86+
print(f'Original shape: {o.shape}')
87+
m.reset_classifier(0)
88+
o = m(torch.randn(2, 3, 224, 224))
89+
print(f'Pooled shape: {o.shape}')
90+
```
91+
Output:
92+
```text
93+
Pooled shape: torch.Size([2, 1024])
94+
```
95+
96+
97+
## Multi-scale Feature Maps (Feature Pyramid)
98+
99+
Object detection, segmentation, keypoint, and a variety of dense pixel tasks require access to feature maps from the backbone network at multiple scales. This is often done by modifying the original classification network. Since each network varies quite a bit in structure, it's not uncommon to see only a few backbones supported in any given obj detection or segmentation library.
100+
101+
`timm` allows a consistent interface for creating any of the included models as feature backbones that output feature maps for selected levels.
102+
103+
A feature backbone can be created by adding the argument `features_only=True` to any `create_model` call. By default 5 strides will be output from most models (not all have that many), with the first starting at 2 (some start at 1 or 4).
104+
105+
### Create a feature map extraction model
106+
```python hl_lines="3"
107+
import torch
108+
import timm
109+
m = timm.create_model('resnest26d', features_only=True, pretrained=True)
110+
o = m(torch.randn(2, 3, 224, 224))
111+
for x in o:
112+
print(x.shape)
113+
```
114+
Output:
115+
```text
116+
torch.Size([2, 64, 112, 112])
117+
torch.Size([2, 256, 56, 56])
118+
torch.Size([2, 512, 28, 28])
119+
torch.Size([2, 1024, 14, 14])
120+
torch.Size([2, 2048, 7, 7])
121+
```
122+
123+
### Query the feature information
124+
125+
After a feature backbone has been created, it can be queried to provide channel or resolution reduction information to the downstream heads without requiring static config or hardcoded constants. The `.feature_info` attribute is a class encapsulating the information about the feature extraction points.
126+
127+
```python hl_lines="3 4"
128+
import torch
129+
import timm
130+
m = timm.create_model('regnety_032', features_only=True, pretrained=True)
131+
print(f'Feature channels: {m.feature_info.channels()}')
132+
o = m(torch.randn(2, 3, 224, 224))
133+
for x in o:
134+
print(x.shape)
135+
```
136+
Output:
137+
```text
138+
Feature channels: [32, 72, 216, 576, 1512]
139+
torch.Size([2, 32, 112, 112])
140+
torch.Size([2, 72, 56, 56])
141+
torch.Size([2, 216, 28, 28])
142+
torch.Size([2, 576, 14, 14])
143+
torch.Size([2, 1512, 7, 7])
144+
```
145+
146+
### Select specific feature levels or limit the stride
147+
148+
There are to additional creation arguments impacting the output features.
149+
150+
* `out_indices` selects which indices to output
151+
* `output_stride` limits the feature output stride of the network (also works in classification mode BTW)
152+
153+
`out_indices` is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check feature_info to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most models, index 0 is the stride 2 features, and index 4 is stride 32.
154+
155+
`output_stride` is achieved by converting layers to use dilated convolutions. Doing so is not always straightforward, some networks only support `output_stride=32`.
156+
157+
```python hl_lines="3 4 5"
158+
import torch
159+
import timm
160+
m = timm.create_model('ecaresnet101d', features_only=True, output_stride=8, out_indices=(2, 4), pretrained=True)
161+
print(f'Feature channels: {m.feature_info.channels()}')
162+
print(f'Feature reduction: {m.feature_info.reduction()}')
163+
o = m(torch.randn(2, 3, 320, 320))
164+
for x in o:
165+
print(x.shape)
166+
```
167+
Output:
168+
```text
169+
Feature channels: [512, 2048]
170+
Feature reduction: [8, 8]
171+
torch.Size([2, 512, 40, 40])
172+
torch.Size([2, 2048, 40, 40])
173+
```

0 commit comments

Comments
 (0)