
Commit 472e0ac
add more configurations.
1 parent 616ef12

File tree

8 files changed: +118 −8 lines changed

README.md

Lines changed: 23 additions & 7 deletions

@@ -31,25 +31,41 @@ transform_train.transforms.insert(0, RandAugment(N, M))
 
 ## Experiment
 
-We use same hyperparameters as the paper mentioned.
+We use the same hyperparameters mentioned in the paper, and observed similar results to those reported.
 
-### CIFAR-10
+You can run an experiment with:
+
+```bash
+$ python RandAugment/train.py -c confs/wresnet28x10_cifar10_b256.yaml --save cifar10_wres28x10.pth
+```
+
+### CIFAR-10 Classification
 
 | Model             | Paper's Result |         Ours |
 |-------------------|---------------:|-------------:|
 | Wide-ResNet 28x10 |           97.3 |         97.4 |
-| Shake26 2x96d     |
-| Pyramid272        | TODO |
+| Shake26 2x96d     |           98.0 |         98.1 |
+| Pyramid272        |           98.5 |              |
 
-### CIFAR-100
+### CIFAR-100 Classification
 
 | Model             | Paper's Result |         Ours |
 |-------------------|---------------:|-------------:|
 | Wide-ResNet 28x10 |           83.3 |         83.3 |
 
-### ImageNet
+### SVHN Classification
 
-TODO
+| Model             | Paper's Result |         Ours |
+|-------------------|---------------:|-------------:|
+| Wide-ResNet 28x10 |           98.9 |         TODO |
+
+### ImageNet Classification
+
+| Model           | Paper's Result |  Ours |
+|-----------------|---------------:|------:|
+| ResNet-50       |    77.6 / 92.8 |  TODO |
+| EfficientNet-B5 |    83.2 / 96.7 |  TODO |
+| EfficientNet-B7 |    84.4 / 97.1 |  TODO |
 
 ## References

RandAugment/data.py

Lines changed: 6 additions & 0 deletions

@@ -5,6 +5,7 @@
 from PIL import Image
 
 from torch.utils.data import SubsetRandomSampler, Sampler
+from torch.utils.data.dataset import ConcatDataset
 from torchvision.transforms import transforms
 from sklearn.model_selection import StratifiedShuffleSplit
 from theconf import Config as C
@@ -80,6 +81,11 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0):
     elif dataset == 'cifar100':
         total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train)
         testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test)
+    elif dataset == 'svhn':
+        trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
+        extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_train)
+        total_trainset = ConcatDataset([trainset, extraset])
+        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
     elif dataset == 'imagenet':
         total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
         testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
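The new `svhn` branch merges SVHN's `train` and `extra` splits with `torch.utils.data.ConcatDataset`, which exposes the two datasets as one index space. A minimal pure-Python sketch of that indexing behavior (a hypothetical stand-in class so the example runs without torch; the real `ConcatDataset` works the same way on any objects with `__len__`/`__getitem__`):

```python
from bisect import bisect_right

class TinyConcatDataset:
    """Illustrative re-implementation of ConcatDataset-style indexing:
    a global index is mapped to (dataset, local index) via cumulative sizes."""
    def __init__(self, datasets):
        self.datasets = list(datasets)
        self.cumulative = []
        total = 0
        for d in self.datasets:
            total += len(d)
            self.cumulative.append(total)

    def __len__(self):
        return self.cumulative[-1]

    def __getitem__(self, idx):
        # Find which dataset the global index falls into.
        ds_idx = bisect_right(self.cumulative, idx)
        prev = self.cumulative[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - prev]

# Toy stand-ins for the 'train' and 'extra' splits:
train = ['t0', 't1', 't2']
extra = ['e0', 'e1']
merged = TinyConcatDataset([train, extra])
```

Indexing `merged` past the end of the first split transparently continues into the second, which is why the rest of `get_dataloaders` can treat `total_trainset` as a single dataset.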

RandAugment/lr_scheduler.py

Lines changed: 2 additions & 0 deletions

@@ -11,6 +11,8 @@ def adjust_learning_rate_resnet(optimizer):
 
     if C.get()['epoch'] == 90:
         return torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80])
+    elif C.get()['epoch'] == 180:
+        return torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 120, 160])
     elif C.get()['epoch'] == 270:   # autoaugment
         return torch.optim.lr_scheduler.MultiStepLR(optimizer, [90, 180, 240])
     else:
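The new 180-epoch branch steps the learning rate down at epochs 60, 120, and 160. `MultiStepLR` multiplies the LR by `gamma` at each milestone (the code above relies on torch's default `gamma=0.1`); a small sketch of the resulting schedule, computed directly rather than via torch:

```python
def multistep_lr(base_lr, milestones, epoch, gamma=0.1):
    """LR after `epoch` epochs under a MultiStepLR-style schedule:
    one factor of `gamma` per milestone already passed."""
    drops = sum(1 for m in milestones if epoch >= m)
    return base_lr * (gamma ** drops)

# The added 180-epoch case: base LR 0.1 decays at 60, 120, 160.
schedule_180 = [multistep_lr(0.1, [60, 120, 160], e) for e in (0, 60, 120, 160)]
```

With `base_lr=0.1` this yields 0.1 until epoch 60, then 0.01, 0.001, and 0.0001 for the final 20 epochs.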

confs/pyramid272_cifar10_b64.yaml

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+model:
+  type: pyramid
+  depth: 272
+  alpha: 200
+  bottleneck: True
+dataset: cifar10
+aug: randaugment
+randaug:
+  N: 3
+  M: 9
+
+cutout: 16
+batch: 64
+epoch: 1800
+lr: 0.05
+lr_schedule:
+  type: 'cosine'
+optimizer:
+  type: sgd
+  nesterov: True
+  decay: 0.00005
confs/resnet200_b256.yaml

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+model:
+  type: resnet200
+dataset: imagenet
+aug: fa_reduced_imagenet
+cutout: 0
+batch: 256
+epoch: 270
+lr: 0.05
+lr_schedule:
+  type: 'resnet'
+  warmup:
+    multiplier: 2
+    epoch: 3
+optimizer:
+  type: sgd
+  nesterov: True
+  decay: 0.0001
+  clip: 0

confs/resnet50_b1024.yaml

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+model:
+  type: resnet50
+dataset: imagenet
+aug: randaugment
+randaug:
+  N: 2
+  M: 9
+
+cutout: 0
+batch: 1024
+epoch: 180  # 270
+lr: 0.1
+lr_schedule:
+  type: 'resnet'
+  warmup:
+    multiplier: 4
+    epoch: 3
+optimizer:
+  type: sgd
+  nesterov: True
+  decay: 0.0001
+  clip: 0

confs/shake26_2x96d_cifar10_b512.yaml

Lines changed: 5 additions & 1 deletion

@@ -1,7 +1,11 @@
 model:
   type: shakeshake26_2x96d
 dataset: cifar10
-aug: fa_reduced_cifar10
+aug: randaugment
+randaug:
+  N: 3
+  M: 9
+
 cutout: 16
 batch: 512
 epoch: 1800

confs/wresnet28x10_svhn_b256.yaml

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+model:
+  type: wresnet28_10
+dataset: svhn
+aug: randaugment
+randaug:
+  N: 3
+  M: 7  # from appendix
+
+cutout: 16
+batch: 256
+epoch: 160
+lr: 0.005
+lr_schedule:
+  type: 'cosine'
+  warmup:
+    multiplier: 2
+    epoch: 5
+optimizer:
+  type: sgd
+  nesterov: True
+  decay: 0.001
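This SVHN config combines a short warmup (`multiplier: 2` over 5 epochs) with cosine annealing over 160 epochs from `lr: 0.005`. A hedged sketch of what such a schedule could look like, assuming linear warmup from the base LR to multiplier×LR followed by cosine decay to zero (the repo's actual warmup/cosine schedulers may differ in detail):

```python
import math

def lr_at(epoch, base_lr=0.005, total=160, warmup_epochs=5, multiplier=2):
    """Illustrative warmup-then-cosine schedule for this config's values."""
    peak = base_lr * multiplier
    if epoch < warmup_epochs:
        # Linear ramp from base_lr up to peak across the warmup epochs.
        return base_lr + (peak - base_lr) * epoch / warmup_epochs
    # Cosine annealing from peak down to 0 over the remaining epochs.
    t = (epoch - warmup_epochs) / (total - warmup_epochs)
    return 0.5 * peak * (1 + math.cos(math.pi * t))
```

Under these assumptions the LR starts at 0.005, peaks at 0.01 when warmup ends at epoch 5, and decays smoothly to 0 by epoch 160.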
