Skip to content

Commit 6949784

Browse files
committed
#9 : code changes for imagenet(resnet) reproducibility
1 parent dab97f8 commit 6949784

File tree

9 files changed

+90
-19
lines changed

9 files changed

+90
-19
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ $ python RandAugment/train.py -c confs/wresnet28x10_cifar10_b256.yaml --save cif
6161

6262
### ImageNet Classification
6363

64+
I have experienced some difficulties while reproducing paper's result.
65+
66+
**Issue : https://github.com/ildoonet/pytorch-randaugment/issues/9**
67+
6468
| Model | Paper's Result | Ours |
6569
|-------------------|---------------:|-------------:|
6670
| ResNet-50 | 77.6 / 92.8 | TODO

RandAugment/augmentations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,13 @@ def augment_list(): # 16 oeprations and their ranges
160160
(Identity, 0., 1.0),
161161
(ShearX, 0., 0.3), # 0
162162
(ShearY, 0., 0.3), # 1
163-
(TranslateX, 0., 0.45), # 2
164-
(TranslateY, 0., 0.45), # 3
163+
(TranslateX, 0., 0.33), # 2
164+
(TranslateY, 0., 0.33), # 3
165165
(Rotate, 0, 30), # 4
166166
(AutoContrast, 0, 1), # 5
167167
(Invert, 0, 1), # 6
168168
(Equalize, 0, 1), # 7
169-
(Solarize, 0, 256), # 8
169+
(Solarize, 0, 110), # 8
170170
(Posterize, 4, 8), # 9
171171
# (Contrast, 0.1, 1.9), # 10
172172
(Color, 0.1, 1.9), # 11

RandAugment/metrics.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,6 @@ def accuracy(output, target, topk=(1,)):
2222
return res
2323

2424

25-
def cross_entropy_smooth(input, target, size_average=True, label_smoothing=0.1):
26-
y = torch.eye(10).cuda()
27-
lb_oh = y[target]
28-
29-
target = lb_oh * (1 - label_smoothing) + 0.5 * label_smoothing
30-
31-
logsoftmax = nn.LogSoftmax()
32-
if size_average:
33-
return torch.mean(torch.sum(-target * logsoftmax(input), dim=1))
34-
else:
35-
return torch.sum(torch.sum(-target * logsoftmax(input), dim=1))
36-
37-
3825
class Accumulator:
3926
def __init__(self):
4027
self.metrics = defaultdict(lambda: 0.)

RandAugment/networks/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def __init__(self, dataset, depth, num_classes, bottleneck=False):
106106
self.fc = nn.Linear(64 * block.expansion, num_classes)
107107

108108
elif dataset == 'imagenet':
109-
blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
110-
layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}
109+
blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
110+
layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}
111111
assert layers[depth], 'invalid detph for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)'
112112

113113
self.inplanes = 64

RandAugment/smooth_ce.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import numpy as np
2+
import torch
3+
from torch.nn.modules.module import Module
4+
5+
6+
class SmoothCrossEntropyLoss(Module):
7+
def __init__(self, label_smoothing=0.0, size_average=True):
8+
super().__init__()
9+
self.label_smoothing = label_smoothing
10+
self.size_average = size_average
11+
12+
def forward(self, input, target):
13+
if len(target.size()) == 1:
14+
target = torch.nn.functional.one_hot(target, num_classes=input.size(-1))
15+
target = target.float().cuda()
16+
if self.label_smoothing > 0.0:
17+
s_by_c = self.label_smoothing / len(input[0])
18+
smooth = torch.zeros_like(target)
19+
smooth = smooth + s_by_c
20+
target = target * (1. - s_by_c) + smooth
21+
22+
return cross_entropy(input, target, self.size_average)
23+
24+
25+
def cross_entropy(input, target, size_average=True):
26+
""" Cross entropy that accepts soft targets
27+
Args:
28+
pred: predictions for neural network
29+
targets: targets, can be soft
30+
size_average: if false, sum is returned instead of mean
31+
Examples::
32+
input = torch.FloatTensor([[1.1, 2.8, 1.3], [1.1, 2.1, 4.8]])
33+
input = torch.autograd.Variable(out, requires_grad=True)
34+
target = torch.FloatTensor([[0.05, 0.9, 0.05], [0.05, 0.05, 0.9]])
35+
target = torch.autograd.Variable(y1)
36+
loss = cross_entropy(input, target)
37+
loss.backward()
38+
"""
39+
logsoftmax = torch.nn.LogSoftmax(dim=1)
40+
if size_average:
41+
return torch.mean(torch.sum(-target * logsoftmax(input), dim=1))
42+
else:
43+
return torch.sum(torch.sum(-target * logsoftmax(input), dim=1))

RandAugment/train.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from warmup_scheduler import GradualWarmupScheduler
2121

2222
from RandAugment.common import add_filehandler
23+
from RandAugment.smooth_ce import SmoothCrossEntropyLoss
2324

2425
logger = get_logger('RandAugment')
2526
logger.setLevel(logging.INFO)
@@ -94,7 +95,11 @@ def train_and_eval(tag, dataroot, test_ratio=0.0, cv_fold=0, reporter=None, metr
9495
# create a model & an optimizer
9596
model = get_model(C.get()['model'], num_class(C.get()['dataset']))
9697

97-
criterion = nn.CrossEntropyLoss()
98+
lb_smooth = C.get()['optimizer'].get('label_smoothing', 0.0)
99+
if lb_smooth > 0.0:
100+
criterion = SmoothCrossEntropyLoss(lb_smooth)
101+
else:
102+
criterion = nn.CrossEntropyLoss()
98103
if C.get()['optimizer']['type'] == 'sgd':
99104
optimizer = optim.SGD(
100105
model.parameters(),
@@ -106,6 +111,11 @@ def train_and_eval(tag, dataroot, test_ratio=0.0, cv_fold=0, reporter=None, metr
106111
else:
107112
raise ValueError('invalid optimizer type=%s' % C.get()['optimizer']['type'])
108113

114+
if C.get()['optimizer'].get('lars', False):
115+
from torchlars import LARS
116+
optimizer = LARS(optimizer)
117+
logger.info('*** LARS Enabled.')
118+
109119
lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
110120
if lr_scheduler_type == 'cosine':
111121
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=C.get()['epoch'], eta_min=0.)

confs/resnet50_b1024.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ optimizer:
2020
nesterov: True
2121
decay: 0.0001
2222
clip: 0
23+
lars: False
24+
label_smoothing: 0.0

confs/resnet50_b512.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
model:
2+
type: resnet50
3+
dataset: imagenet
4+
aug: randaugment
5+
randaug:
6+
N: 2
7+
M: 9
8+
9+
cutout: 0
10+
batch: 512
11+
epoch: 180 # 270
12+
lr: 0.1
13+
lr_schedule:
14+
type: 'resnet'
15+
warmup:
16+
multiplier: 2
17+
epoch: 3
18+
optimizer:
19+
type: sgd
20+
nesterov: True
21+
decay: 0.0001
22+
clip: 0
23+
lars: False
24+
label_smoothing: 0.0

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ ray
1111
matplotlib
1212
psutil
1313
requests
14+
torchlars

0 commit comments

Comments
 (0)