Skip to content

Commit f471c17

Browse files
committed
More cutmix/mixup overhaul, ready to kick-off some trials.
1 parent 92f2d0d commit f471c17

File tree

2 files changed

+159
-105
lines changed

2 files changed

+159
-105
lines changed

timm/data/mixup.py

Lines changed: 141 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,6 @@
1515
import torch
1616
import math
1717
import numbers
18-
from enum import IntEnum
19-
20-
21-
class MixupMode(IntEnum):
22-
MIXUP = 0
23-
CUTMIX = 1
24-
RANDOM = 2
25-
26-
@classmethod
27-
def from_str(cls, value):
28-
return cls[value.upper()]
2918

3019

3120
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
@@ -50,132 +39,185 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab
5039
return input, target
5140

5241

53-
def calc_ratio(lam, minmax=None):
42+
def rand_bbox(size, lam, border=0., count=None):
5443
ratio = math.sqrt(1 - lam)
55-
if minmax is not None:
56-
if isinstance(minmax, numbers.Number):
57-
minmax = (minmax, 1 - minmax)
58-
ratio = np.clip(ratio, minmax[0], minmax[1])
59-
return ratio
60-
61-
62-
def rand_bbox(size, ratio):
63-
H, W = size[-2:]
64-
cut_h, cut_w = int(H * ratio), int(W * ratio)
65-
cy, cx = np.random.randint(H), np.random.randint(W)
66-
yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
67-
xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
44+
img_h, img_w = size[-2:]
45+
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
46+
margin_y, margin_x = int(border * cut_h), int(border * cut_w)
47+
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
48+
cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
49+
yl = np.clip(cy - cut_h // 2, 0, img_h)
50+
yh = np.clip(cy + cut_h // 2, 0, img_h)
51+
xl = np.clip(cx - cut_w // 2, 0, img_w)
52+
xh = np.clip(cx + cut_w // 2, 0, img_w)
6853
return yl, yh, xl, xh
6954

7055

56+
def rand_bbox_minmax(size, minmax, count=None):
57+
assert len(minmax) == 2
58+
img_h, img_w = size[-2:]
59+
cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
60+
cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
61+
yl = np.random.randint(0, img_h - cut_h, size=count)
62+
xl = np.random.randint(0, img_w - cut_w, size=count)
63+
yu = yl + cut_h
64+
xu = xl + cut_w
65+
return yl, yu, xl, xu
66+
67+
68+
def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
69+
if ratio_minmax is not None:
70+
yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
71+
else:
72+
yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
73+
if correct_lam or ratio_minmax is not None:
74+
bbox_area = (yu - yl) * (xu - xl)
75+
lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1])
76+
return (yl, yu, xl, xu), lam
77+
78+
7179
def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
7280
lam = 1.
7381
if not disable:
7482
lam = np.random.beta(alpha, alpha)
7583
if lam != 1:
76-
yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam))
84+
yl, yh, xl, xh = rand_bbox(input.size(), lam)
7785
input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
7886
if correct_lam:
7987
lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
8088
target = mixup_target(target, num_classes, lam, smoothing)
8189
return input, target
8290

8391

84-
def _resolve_mode(mode):
85-
mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
86-
if mode == MixupMode.RANDOM:
87-
mode = MixupMode(np.random.rand() > 0.7)
88-
return mode # will be one of cutmix or mixup
89-
90-
9192
def mix_batch(
92-
input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
93-
mode = _resolve_mode(mode)
94-
if mode == MixupMode.CUTMIX:
95-
return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
93+
input, target, mixup_alpha=0.2, cutmix_alpha=0., prob=1.0, switch_prob=.5,
94+
num_classes=1000, smoothing=0.1, disable=False):
95+
# FIXME test this version
96+
if np.random.rand() > prob:
97+
return input, target
98+
use_cutmix = cutmix_alpha > 0. and np.random.rand() <= switch_prob
99+
if use_cutmix:
100+
return cutmix_batch(input, target, cutmix_alpha, num_classes, smoothing, disable)
96101
else:
97-
return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
102+
return mixup_batch(input, target, mixup_alpha, num_classes, smoothing, disable)
98103

99104

100105
class FastCollateMixup:
101-
"""Fast Collate Mixup that applies different params to each element + flipped pair
106+
"""Fast Collate Mixup/Cutmix that applies different params to each element or whole batch
102107
103108
NOTE once experiments are done, one of the three variants will remain with this class name
109+
104110
"""
105-
def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
111+
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
112+
elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
113+
"""
114+
115+
Args:
116+
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
117+
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
118+
cutmix_minmax (float): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None
119+
prob (float): probability of applying mixup or cutmix per batch or element
120+
switch_prob (float): probability of using cutmix instead of mixup when both active
121+
elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
122+
label_smoothing (float):
123+
num_classes (int):
124+
"""
106125
self.mixup_alpha = mixup_alpha
126+
self.cutmix_alpha = cutmix_alpha
127+
self.cutmix_minmax = cutmix_minmax
128+
if self.cutmix_minmax is not None:
129+
assert len(self.cutmix_minmax) == 2
130+
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
131+
self.cutmix_alpha = 1.0
132+
self.prob = prob
133+
self.switch_prob = switch_prob
107134
self.label_smoothing = label_smoothing
108135
self.num_classes = num_classes
109-
self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
110-
self.mixup_enabled = True
111-
self.correct_lam = True # correct lambda based on clipped area for cutmix
112-
self.ratio_minmax = None # (0.2, 0.8)
136+
self.elementwise = elementwise
137+
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
138+
self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
113139

114-
def _do_mix(self, tensor, batch):
140+
def _mix_elem(self, output, batch):
115141
batch_size = len(batch)
116-
lam_out = torch.ones(batch_size)
142+
lam_out = np.ones(batch_size)
143+
use_cutmix = np.zeros(batch_size).astype(np.bool)
144+
if self.mixup_enabled:
145+
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
146+
use_cutmix = np.random.rand(batch_size) < self.switch_prob
147+
lam_mix = np.where(
148+
use_cutmix,
149+
np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
150+
np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
151+
elif self.mixup_alpha > 0.:
152+
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
153+
elif self.cutmix_alpha > 0.:
154+
use_cutmix = np.ones(batch_size).astype(np.bool)
155+
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
156+
else:
157+
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
158+
lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out)
159+
117160
for i in range(batch_size):
118161
j = batch_size - i - 1
119-
lam = 1.
120-
if self.mixup_enabled:
121-
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
122-
123-
if _resolve_mode(self.mode) == MixupMode.CUTMIX:
124-
mixed = batch[i][0].astype(np.float32)
125-
if lam != 1:
126-
ratio = calc_ratio(lam)
127-
yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
162+
lam = lam_out[i]
163+
mixed = batch[i][0].astype(np.float32)
164+
if lam != 1.:
165+
if use_cutmix[i]:
166+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
167+
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
128168
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
129-
if self.correct_lam:
130-
lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
131-
else:
132-
lam_out[i] = lam
169+
lam_out[i] = lam
170+
else:
171+
mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
172+
lam_out[i] = lam
173+
np.round(mixed, out=mixed)
174+
output[i] += torch.from_numpy(mixed.astype(np.uint8))
175+
return torch.tensor(lam_out).unsqueeze(1)
176+
177+
def _mix_batch(self, output, batch):
178+
batch_size = len(batch)
179+
lam = 1.
180+
use_cutmix = False
181+
if self.mixup_enabled and np.random.rand() < self.prob:
182+
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
183+
use_cutmix = np.random.rand() < self.switch_prob
184+
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
185+
np.random.beta(self.mixup_alpha, self.mixup_alpha)
186+
elif self.mixup_alpha > 0.:
187+
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
188+
elif self.cutmix_alpha > 0.:
189+
use_cutmix = True
190+
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
133191
else:
134-
mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
135-
lam_out[i] = lam
136-
np.round(mixed, out=mixed)
137-
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
138-
return lam_out.unsqueeze(1)
192+
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
193+
lam = lam_mix
194+
195+
if use_cutmix:
196+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
197+
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
198+
199+
for i in range(batch_size):
200+
j = batch_size - i - 1
201+
mixed = batch[i][0].astype(np.float32)
202+
if lam != 1.:
203+
if use_cutmix:
204+
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
205+
else:
206+
mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
207+
np.round(mixed, out=mixed)
208+
output[i] += torch.from_numpy(mixed.astype(np.uint8))
209+
return lam
139210

140211
def __call__(self, batch):
141212
batch_size = len(batch)
142213
assert batch_size % 2 == 0, 'Batch size should be even when using this'
143-
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
144-
lam = self._do_mix(tensor, batch)
214+
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
215+
if self.elementwise:
216+
lam = self._mix_elem(output, batch)
217+
else:
218+
lam = self._mix_batch(output, batch)
145219
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
146220
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
147221

148-
return tensor, target
149-
150-
151-
class FastCollateMixupBatchwise(FastCollateMixup):
152-
"""Fast Collate Mixup that applies same params to whole batch
153-
154-
NOTE this is for experimentation, may remove at some point
155-
"""
156-
157-
def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
158-
super(FastCollateMixupBatchwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode)
222+
return output, target
159223

160-
def _do_mix(self, tensor, batch):
161-
batch_size = len(batch)
162-
lam = 1.
163-
cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
164-
if self.mixup_enabled:
165-
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
166-
if cutmix:
167-
yl, yh, xl, xh = rand_bbox(batch[0][0].shape, calc_ratio(lam))
168-
if self.correct_lam:
169-
lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
170-
171-
for i in range(batch_size):
172-
j = batch_size - i - 1
173-
if cutmix:
174-
mixed = batch[i][0].astype(np.float32)
175-
if lam != 1:
176-
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
177-
else:
178-
mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
179-
np.round(mixed, out=mixed)
180-
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
181-
return lam

train.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,16 @@
157157
help='Do not random erase first (clean) augmentation split')
158158
parser.add_argument('--mixup', type=float, default=0.0,
159159
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
160-
parser.add_argument('--mixup-mode', type=str, default='mixup',
161-
help='Mixup mode. One of "mixup", "cutmix", "random" (default: "mixup")')
160+
parser.add_argument('--cutmix', type=float, default=0.0,
161+
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
162+
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
163+
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
164+
parser.add_argument('--mixup-prob', type=float, default=1.0,
165+
help='Probability of performing mixup or cutmix when either/both is enabled')
166+
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
167+
help='Probability of switching to cutmix when both mixup and cutmix enabled')
168+
parser.add_argument('--mixup-elem', action='store_true', default=False,
169+
help='Apply mixup/cutmix params uniquely per batch element instead of per batch.')
162170
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
163171
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
164172
parser.add_argument('--smoothing', type=float, default=0.1,
@@ -390,9 +398,12 @@ def main():
390398
dataset_train = Dataset(train_dir)
391399

392400
collate_fn = None
393-
if args.prefetcher and args.mixup > 0:
401+
if args.prefetcher and (args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None):
394402
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
395-
collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes, args.mixup_mode)
403+
collate_fn = FastCollateMixup(
404+
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
405+
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, elementwise=args.mixup_elem,
406+
label_smoothing=args.smoothing, num_classes=args.num_classes)
396407

397408
if num_aug_splits > 1:
398409
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
@@ -555,8 +566,9 @@ def train_epoch(
555566
if args.mixup > 0.:
556567
input, target = mix_batch(
557568
input, target,
558-
alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
559-
disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch, mode=args.mixup_mode)
569+
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, prob=args.mixup_prob,
570+
switch_prob=args.mixup_switch_prob, num_classes=args.num_classes, smoothing=args.smoothing,
571+
disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)
560572

561573
output = model(input)
562574

0 commit comments

Comments
 (0)