
Commit 616ef12

reproduce results for wresnet28x10 @ cifar10/100.
update usage with setup.py.

Parent: d259a44

8 files changed: 116 additions, 7 deletions

README.md (42 additions, 1 deletion)
@@ -4,13 +4,54 @@ Unofficial PyTorch Reimplementation of RandAugment. Most of codes are from [Fast
 
 ## Introduction
 
+TODO
+
 ## Install
 
+```bash
+$ pip install git+https://github.com/ildoonet/pytorch-randaugment
+```
+
 ## Usage
 
+```python
+from torchvision.transforms import transforms
+from RandAugment import RandAugment
+
+transform_train = transforms.Compose([
+    transforms.RandomCrop(32, padding=4),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
+])
+
+# Add RandAugment with N and M (hyperparameters)
+transform_train.transforms.insert(0, RandAugment(N, M))
+```
+
 ## Experiment
 
+We use the same hyperparameters as mentioned in the paper.
+
+### CIFAR-10
+
+| Model             | Paper's Result |         Ours |
+|-------------------|---------------:|-------------:|
+| Wide-ResNet 28x10 |           97.3 |         97.4 |
+| Shake26 2x96d     |                |              |
+| Pyramid272        |           TODO |              |
+
+### CIFAR-100
+
+| Model             | Paper's Result |         Ours |
+|-------------------|---------------:|-------------:|
+| Wide-ResNet 28x10 |           83.3 |         83.3 |
+
+### ImageNet
+
+TODO
+
 ## References
 
-- RandAugment : https://arxiv.org/abs/1909.13719
+- RandAugment : [Paper](https://arxiv.org/abs/1909.13719)
 - Fast AutoAugment : [Code](https://github.com/kakaobrain/fast-autoaugment) [Paper](https://arxiv.org/abs/1905.00397)
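Note: the README usage snippet above is not self-contained, since `_CIFAR_MEAN`, `_CIFAR_STD`, `N`, and `M` are never defined. Below is a minimal runnable sketch; the CIFAR-10 statistics are the commonly used values rather than constants taken from this repo, and the N/M values are borrowed from the confs/wresnet28x10_cifar10_b256.yaml added in this commit.

```python
from torchvision.transforms import transforms

from RandAugment import RandAugment

# Assumed CIFAR-10 channel statistics; the repo's own _CIFAR_MEAN / _CIFAR_STD may differ.
_CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)

# N = number of transforms applied per image, M = shared magnitude,
# taken here from confs/wresnet28x10_cifar10_b256.yaml.
N, M = 3, 5

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])

# As in the README: prepend RandAugment so it runs before ToTensor.
transform_train.transforms.insert(0, RandAugment(N, M))
```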

RandAugment/__init__.py (2 additions, 0 deletions)
@@ -0,0 +1,2 @@
+
+from RandAugment.augmentations import RandAugment

RandAugment/networks/wideresnet.py (6 additions, 3 deletions)
@@ -4,6 +4,9 @@
 import numpy as np
 
 
+_bn_momentum = 0.1
+
+
 def conv3x3(in_planes, out_planes, stride=1):
     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
 
@@ -21,10 +24,10 @@ def conv_init(m):
 class WideBasic(nn.Module):
     def __init__(self, in_planes, planes, dropout_rate, stride=1):
         super(WideBasic, self).__init__()
-        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.9)
+        self.bn1 = nn.BatchNorm2d(in_planes, momentum=_bn_momentum)
         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
         self.dropout = nn.Dropout(p=dropout_rate)
-        self.bn2 = nn.BatchNorm2d(planes, momentum=0.9)
+        self.bn2 = nn.BatchNorm2d(planes, momentum=_bn_momentum)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
 
         self.shortcut = nn.Sequential()
@@ -56,7 +59,7 @@ def __init__(self, depth, widen_factor, dropout_rate, num_classes):
         self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
         self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
         self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
-        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
+        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=_bn_momentum)
         self.linear = nn.Linear(nStages[3], num_classes)
 
         # self.apply(conv_init)
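A note on the `momentum` change above: in PyTorch, the `momentum` of `nn.BatchNorm2d` is the weight given to the current batch statistic when the running estimates are updated (the reverse of the convention in some other frameworks), and its default is 0.1. With `momentum=0.9`, the running mean/variance were almost entirely overwritten by every batch; `_bn_momentum = 0.1` restores the slow exponential moving average. A small sketch of the update rule:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16, momentum=0.1)  # same setting as _bn_momentum above
x = torch.randn(8, 16, 32, 32)

before = bn.running_mean.clone()
_ = bn(x)  # module is in training mode by default, so running stats update

# PyTorch's rule: running_mean <- (1 - momentum) * running_mean + momentum * batch_mean
batch_mean = x.mean(dim=(0, 2, 3))
expected = (1 - 0.1) * before + 0.1 * batch_mean
assert torch.allclose(bn.running_mean, expected, atol=1e-5)
```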

confs/shake26_2x96d_cifar10_b512.yaml (17 additions, 0 deletions)
@@ -0,0 +1,17 @@
+model:
+  type: shakeshake26_2x96d
+dataset: cifar10
+aug: fa_reduced_cifar10
+cutout: 16
+batch: 512
+epoch: 1800
+lr: 0.01
+lr_schedule:
+  type: 'cosine'
+  warmup:
+    multiplier: 4
+    epoch: 5
+optimizer:
+  type: sgd
+  nesterov: True
+  decay: 0.001
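For orientation, the `lr_schedule` block above appears to describe cosine annealing with gradual warmup. The sketch below is one plausible reading (linear warmup from `lr` to `lr * multiplier` over `warmup.epoch` epochs, then cosine decay to zero); the repository's actual scheduler may differ in details such as per-step versus per-epoch updates.

```python
import math

def lr_at_epoch(epoch, base_lr=0.01, total=1800, warmup_epochs=5, multiplier=4):
    """Sketch of 'cosine' + warmup as configured above (assumed semantics)."""
    if epoch < warmup_epochs:
        # ramp linearly from base_lr up to base_lr * multiplier
        return base_lr * (1.0 + (multiplier - 1.0) * epoch / warmup_epochs)
    progress = (epoch - warmup_epochs) / (total - warmup_epochs)
    return base_lr * multiplier * 0.5 * (1.0 + math.cos(math.pi * progress))

print(lr_at_epoch(0), lr_at_epoch(5), lr_at_epoch(1800))  # 0.01 -> 0.04 -> 0.0
```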

confs/wresnet28x10_cifar100_b512.yaml renamed to confs/wresnet28x10_cifar100_b256.yaml (2 additions, 2 deletions)
@@ -6,13 +6,13 @@ randaug:
   N: 2
   M: 14
 cutout: 16
-batch: 512
+batch: 256
 epoch: 200
 lr: 0.1
 lr_schedule:
   type: 'cosine'
   warmup:
-    multiplier: 4
+    multiplier: 2
     epoch: 5
 optimizer:
   type: sgd

confs/wresnet28x10_cifar10_b256.yaml (20 additions, 0 deletions)
@@ -0,0 +1,20 @@
+model:
+  type: wresnet28_10
+dataset: cifar10
+aug: randaugment
+randaug:
+  N: 3
+  M: 5  # from appendix
+cutout: 16
+batch: 256
+epoch: 200
+lr: 0.1
+lr_schedule:
+  type: 'cosine'
+  warmup:
+    multiplier: 2
+    epoch: 5
+optimizer:
+  type: sgd
+  nesterov: True
+  decay: 0.0005
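As a rough sketch of how a training script might consume this config: `yaml.safe_load` is standard PyYAML, but the field wiring below is illustrative and not necessarily this repo's actual entry point.

```python
import yaml

from RandAugment import RandAugment

with open('confs/wresnet28x10_cifar10_b256.yaml') as f:
    conf = yaml.safe_load(f)

# Map the config's randaug block onto the transform shown in the README.
if conf.get('aug') == 'randaugment':
    augment = RandAugment(conf['randaug']['N'], conf['randaug']['M'])  # N=3, M=5 here
```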

requirements.txt (1 addition, 1 deletion)
@@ -10,4 +10,4 @@ sklearn
 ray
 matplotlib
 psutil
-requests
\ No newline at end of file
+requests

setup.py (26 additions, 0 deletions)
@@ -0,0 +1,26 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import setuptools
+
+_VERSION = '0.1'
+
+# 'opencv-python >= 3.3.1'
+REQUIRED_PACKAGES = [
+]
+
+DEPENDENCY_LINKS = [
+]
+
+setuptools.setup(
+    name='RandAugment',
+    version=_VERSION,
+    description='Unofficial PyTorch Reimplementation of RandAugment',
+    install_requires=REQUIRED_PACKAGES,
+    dependency_links=DEPENDENCY_LINKS,
+    url='https://github.com/ildoonet/pytorch-randaugment',
+    license='MIT License',
+    package_dir={},
+    packages=setuptools.find_packages(exclude=['tests']),
+)
