Commit b3cb5f3

Working on CutMix impl as per #8, integrating with Mixup, currently experimenting...

1 parent 569419b commit b3cb5f3

4 files changed: +183 -16 lines

timm/data/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,6 +4,6 @@
 from .transforms import *
 from .loader import create_loader
 from .transforms_factory import create_transform
-from .mixup import mixup_batch, FastCollateMixup
+from .mixup import mix_batch, FastCollateMixup, FastCollateMixupBatchwise, FastCollateMixupElementwise
 from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
     rand_augment_transform, auto_augment_transform

timm/data/dataset.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def __getitem__(self, index):
         return img, target

     def __len__(self):
-        return len(self.imgs)
+        return len(self.samples)

     def filenames(self, indices=[], basename=False):
         if indices:

timm/data/mixup.py

Lines changed: 174 additions & 10 deletions
@@ -1,5 +1,30 @@
+""" Mixup and Cutmix
+
+Papers:
+mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
+
+CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
+
+Code Reference:
+CutMix: https://github.com/clovaai/CutMix-PyTorch
+
+Hacked together by Ross Wightman
+"""
+
 import numpy as np
 import torch
+import math
+from enum import IntEnum
+
+
+class MixupMode(IntEnum):
+    MIXUP = 0
+    CUTMIX = 1
+    RANDOM = 2
+
+    @classmethod
+    def from_str(cls, value):
+        return cls[value.upper()]


 def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
@@ -12,7 +37,7 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
     on_value = 1. - smoothing + off_value
     y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
     y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
-    return lam*y1 + (1. - lam)*y2
+    return y1 * lam + y2 * (1. - lam)


 def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
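For reference, a small worked example of what mixup_target produces (values here are hypothetical, and it is run on CPU since the helper defaults to device='cuda'):

target = torch.tensor([0, 2])
soft = mixup_target(target, num_classes=3, lam=0.7, smoothing=0.0, device='cpu')
# each one-hot row is blended with its batch-flipped counterpart:
# soft -> [[0.7, 0.0, 0.3],
#          [0.3, 0.0, 0.7]]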
@@ -24,28 +49,167 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
     return input, target


+def rand_bbox(size, ratio):
+    H, W = size[-2:]
+    ratio = max(min(ratio, 0.8), 0.2)
+    cut_h, cut_w = int(H * ratio), int(W * ratio)
+    cy, cx = np.random.randint(H), np.random.randint(W)
+    yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
+    xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
+    return yl, yh, xl, xh
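A quick sketch of rand_bbox behavior (the shape and ratio below are illustrative; the center is drawn at random, so a box near a border comes back clipped):

# ratio is clamped to [0.2, 0.8]; ratio=0.5 on a 224x224 image requests a
# 112x112 cut region around a random center point
yl, yh, xl, xh = rand_bbox((8, 3, 224, 224), ratio=0.5)
assert 0 <= yl <= yh <= 224 and 0 <= xl <= xh <= 224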
+def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
+    lam = 1.
+    if not disable:
+        lam = np.random.beta(alpha, alpha)
+    if lam != 1:
+        ratio = math.sqrt(1. - lam)
+        yl, yh, xl, xh = rand_bbox(input.size(), ratio)
+        input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
+    target = mixup_target(target, num_classes, lam, smoothing)
+    return input, target
+
+
+def _resolve_mode(mode):
+    mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
+    if mode == MixupMode.RANDOM:
+        mode = MixupMode(np.random.rand() > 0.5)
+    return mode  # will be one of cutmix or mixup
+
+
+def mix_batch(
+        input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
+    mode = _resolve_mode(mode)
+    if mode == MixupMode.CUTMIX:
+        # dispatch fixed: CUTMIX must call cutmix_batch (the original draft had
+        # the two branches swapped)
+        return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
+    else:
+        return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
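A minimal usage sketch for the non-prefetcher path (batch shapes are illustrative, and tensors are assumed to live on GPU since mixup_target defaults to device='cuda'):

images = torch.rand(8, 3, 224, 224).cuda()
labels = torch.randint(0, 1000, (8,)).cuda()
# mode accepts 'mixup', 'cutmix', or 'random' (picked per call)
images, soft_targets = mix_batch(
    images, labels, alpha=0.2, num_classes=1000, smoothing=0.1, mode='random')
# soft_targets is (8, 1000); pair it with SoftTargetCrossEntropy, not a hard-label loss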
 class FastCollateMixup:
+    """Fast Collate Mixup that applies different params to each element + flipped pair
 
-    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
+    NOTE once experiments are done, one of the three variants will remain with this class name
+    """
+    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
         self.mixup_alpha = mixup_alpha
         self.label_smoothing = label_smoothing
         self.num_classes = num_classes
+        self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
         self.mixup_enabled = True
+        self.correct_lam = False  # correct lambda based on clipped area for cutmix
+
+    def _do_mix(self, tensor, batch):
+        batch_size = len(batch)
+        lam_out = torch.ones(batch_size)
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = 1.
+            if self.mixup_enabled:
+                lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+
+            if _resolve_mode(self.mode) == MixupMode.CUTMIX:
+                mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32)
+                ratio = math.sqrt(1. - lam)
+                if lam != 1:
+                    yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
+                    mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
+                    mixed_j[:, yl:yh, xl:xh] = batch[i][0][:, yl:yh, xl:xh].astype(np.float32)
+                    if self.correct_lam:
+                        lam_corrected = (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
+                        lam_out[i] -= lam_corrected
+                        lam_out[j] -= lam_corrected
+                    else:
+                        lam_out[i] = lam
+                        lam_out[j] = lam
+            else:
+                mixed_i = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                mixed_j = batch[j][0].astype(np.float32) * lam + batch[i][0].astype(np.float32) * (1 - lam)
+                lam_out[i] = lam
+                lam_out[j] = lam
+            np.round(mixed_i, out=mixed_i)
+            np.round(mixed_j, out=mixed_j)
+            tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8))
+            tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+        return lam_out
 
     def __call__(self, batch):
         batch_size = len(batch)
+        assert batch_size % 2 == 0, 'Batch size should be even when using this'
+        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        lam = self._do_mix(tensor, batch)
+        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+        target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu')
+
+        return tensor, target
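A sketch of wiring the collate class into a DataLoader by hand (create_loader normally does this; train_dataset is a placeholder, and each sample is assumed to be a (CHW uint8 numpy array, int label) pair as produced by the prefetcher pipeline):

mixup_collate = FastCollateMixup(
    mixup_alpha=0.2, label_smoothing=0.1, num_classes=1000, mode='cutmix')
loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True, collate_fn=mixup_collate)
for input, soft_targets in loader:  # input stays uint8; batch size must be even
    ...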
+
+
+class FastCollateMixupElementwise(FastCollateMixup):
+    """Fast Collate Mixup that applies different params to each batch element
+
+    NOTE this is for experimentation, may remove at some point
+    """
+    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
+        super(FastCollateMixupElementwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode)
+
+    def _do_mix(self, tensor, batch):
+        batch_size = len(batch)
+        lam_out = torch.ones(batch_size)
+        for i in range(batch_size):
+            lam = 1.
+            if self.mixup_enabled:
+                lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+
+            if _resolve_mode(self.mode) == MixupMode.CUTMIX:
+                mixed = batch[i][0].astype(np.float32)
+                ratio = math.sqrt(1. - lam)
+                if lam != 1:
+                    yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
+                    mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
+                    if self.correct_lam:
+                        lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
+                    else:
+                        lam_out[i] = lam
+            else:
+                mixed = batch[i][0].astype(np.float32) * lam + \
+                    batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+                lam_out[i] = lam
+            np.round(mixed, out=mixed)
+            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return lam_out
+
+
+class FastCollateMixupBatchwise(FastCollateMixup):
+    """Fast Collate Mixup that applies same params to whole batch
+
+    NOTE this is for experimentation, may remove at some point
+    """
+    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
+        super(FastCollateMixupBatchwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode)
+
+    def _do_mix(self, tensor, batch):
+        batch_size = len(batch)
+        lam_out = torch.ones(batch_size)
         lam = 1.
+        cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
         if self.mixup_enabled:
             lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            if cutmix:
+                # the bbox is needed whenever cutmix is active, not only when
+                # correct_lam is set, since the paste below relies on it
+                ratio = math.sqrt(1. - lam)
+                yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio)
+                if self.correct_lam:
+                    # e.g. lam = 0.75 -> ratio = sqrt(0.25) = 0.5; an unclipped box
+                    # covers 112*112 / (224*224) = 0.25 of a 224x224 image, recovering
+                    # lam = 0.75. Clipping at a border shrinks the pasted area and
+                    # raises the corrected lam accordingly.
+                    lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
 
-        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
-
-        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
         for i in range(batch_size):
-            mixed = batch[i][0].astype(np.float32) * lam + \
-                batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+            if cutmix:
+                mixed = batch[i][0].astype(np.float32)
+                if lam != 1:
+                    mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
+            else:
+                mixed = batch[i][0].astype(np.float32) * lam + \
+                    batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
             np.round(mixed, out=mixed)
             tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
-
-        return tensor, target
+        # one lambda for the whole batch, returned as a tensor so the base
+        # __call__ can unsqueeze it
+        return lam_out * lam

train.py

Lines changed: 7 additions & 4 deletions
@@ -28,7 +28,8 @@
     from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False
 
-from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset
+from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mix_batch, AugMixDataset,\
+    FastCollateMixupElementwise, FastCollateMixupBatchwise
 from timm.models import create_model, resume_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy

@@ -134,6 +135,8 @@
                     help='Do not random erase first (clean) augmentation split')
 parser.add_argument('--mixup', type=float, default=0.0,
                     help='mixup alpha, mixup enabled if > 0. (default: 0.)')
+parser.add_argument('--mixup-mode', type=str, default='mixup',
+                    help='Mixup mode ("mixup", "cutmix", "random", default: "mixup")')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,

@@ -352,7 +355,7 @@ def main():
     collate_fn = None
     if args.prefetcher and args.mixup > 0:
         assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
-        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
+        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes, args.mixup_mode)
 
     if num_aug_splits > 1:
         dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

@@ -504,10 +507,10 @@ def train_epoch(
         if not args.prefetcher:
             input, target = input.cuda(), target.cuda()
             if args.mixup > 0.:
-                input, target = mixup_batch(
+                input, target = mix_batch(
                     input, target,
                     alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
-                    disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)
+                    disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch, mode=args.mixup_mode)
 
         output = model(input)
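With the new --mixup-mode argument above, cutmix should be selectable from the command line, e.g. (data path and model name are placeholders):

python train.py /data/imagenet --model resnet50 --mixup 0.2 --mixup-mode cutmix
# or choose randomly between mixup and cutmix for each mix:
python train.py /data/imagenet --model resnet50 --mixup 0.2 --mixup-mode random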