Commit 47a7b3b

More flexible mixup mode, add 'half' mode.
1 parent 532e3b4 commit 47a7b3b
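This commit replaces the boolean `elementwise` flag on `Mixup` / `FastCollateMixup` with a string `mode` argument: 'batch' (the old default), 'pair' (mix mirrored pairs of elements), 'elem' (per-element params), plus 'half' for the collate version. A minimal sketch of how the updated constructor might be called; the parameter values below are illustrative and not taken from this commit:

import torch
from timm.data.mixup import Mixup

# 'pair' mixes mirrored pairs of elements; 'elem' draws params per element;
# 'batch' keeps the old whole-batch behaviour. 'half' only applies to FastCollateMixup.
mixup_fn = Mixup(
    mixup_alpha=0.2, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
    mode='pair', label_smoothing=0.1, num_classes=1000)

x = torch.rand(8, 3, 224, 224)             # batch size must be even (asserted in __call__)
target = torch.randint(0, 1000, (8,))
x_mixed, soft_target = mixup_fn(x, target)  # soft_target is a smoothed (8, 1000) tensor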

2 files changed: +79 -16 lines (timm/data/mixup.py, train.py)


timm/data/mixup.py

Lines changed: 76 additions & 13 deletions
@@ -96,13 +96,13 @@ class Mixup:
         cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
         prob (float): probability of applying mixup or cutmix per batch or element
         switch_prob (float): probability of switching to cutmix instead of mixup when both are active
-        elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
         correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
         label_smoothing (float): apply label smoothing to the mixed target tensor
         num_classes (int): number of classes for target
     """
     def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
-                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
+                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
         self.mixup_alpha = mixup_alpha
         self.cutmix_alpha = cutmix_alpha
         self.cutmix_minmax = cutmix_minmax
@@ -114,7 +114,7 @@ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0
         self.switch_prob = switch_prob
         self.label_smoothing = label_smoothing
         self.num_classes = num_classes
-        self.elementwise = elementwise
+        self.mode = mode
         self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
         self.mixup_enabled = True  # set to false to disable mixing (intended tp be set by train loop)

@@ -173,6 +173,26 @@ def _mix_elem(self, x):
                     x[i] = x[i] * lam + x_orig[j] * (1 - lam)
         return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)

+    def _mix_pair(self, x):
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        x_orig = x.clone()  # need to keep an unmodified original for mixing source
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+                    x[j] = x[j] * lam + x_orig[i] * (1 - lam)
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
     def _mix_batch(self, x):
         lam, use_cutmix = self._params_per_batch()
         if lam == 1.:
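The new `_mix_pair` path mixes element i with its mirror j = batch_size - i - 1 and samples only batch_size // 2 lambdas, mirroring them so both members of a pair carry the same mixing weight. A small, hedged illustration of that index/lambda bookkeeping (not code from the commit):

import numpy as np

batch_size = 6
lam_half = np.array([0.9, 0.6, 0.75])                         # one lambda per pair
pairs = [(i, batch_size - i - 1) for i in range(batch_size // 2)]
lam_full = np.concatenate((lam_half, lam_half[::-1]))         # mirrored for the second half
print(pairs)     # [(0, 5), (1, 4), (2, 3)]
print(lam_full)  # [0.9 0.6 0.75 0.75 0.6 0.9] -- index i and its mirror share a lambda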
@@ -188,7 +208,12 @@ def _mix_batch(self, x):

     def __call__(self, x, target):
         assert len(x) % 2 == 0, 'Batch size should be even when using this'
-        lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
+        if self.mode == 'elem':
+            lam = self._mix_elem(x)
+        elif self.mode == 'pair':
+            lam = self._mix_pair(x)
+        else:
+            lam = self._mix_batch(x)
         target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
         return x, target

@@ -199,25 +224,57 @@ class FastCollateMixup(Mixup):
     A Mixup impl that's performed while collating the batches.
     """

-    def _mix_elem_collate(self, output, batch):
+    def _mix_elem_collate(self, output, batch, half=False):
         batch_size = len(batch)
-        lam_batch, use_cutmix = self._params_per_elem(batch_size)
-        for i in range(batch_size):
+        num_elem = batch_size // 2 if half else batch_size
+        assert len(output) == num_elem
+        lam_batch, use_cutmix = self._params_per_elem(num_elem)
+        for i in range(num_elem):
             j = batch_size - i - 1
             lam = lam_batch[i]
             mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix[i]:
-                    mixed = mixed.copy()
+                    if not half:
+                        mixed = mixed.copy()
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                         output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                     lam_batch[i] = lam
                 else:
                     mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    lam_batch[i] = lam
-                    np.round(mixed, out=mixed)
+                    np.rint(mixed, out=mixed)
             output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        if half:
+            lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
+        return torch.tensor(lam_batch).unsqueeze(1)
+
+    def _mix_pair_collate(self, output, batch):
+        batch_size = len(batch)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            mixed_i = batch[i][0]
+            mixed_j = batch[j][0]
+            assert 0 <= lam <= 1.0
+            if lam < 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    patch_i = mixed_i[:, yl:yh, xl:xh].copy()
+                    mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
+                    mixed_j[:, yl:yh, xl:xh] = patch_i
+                    lam_batch[i] = lam
+                else:
+                    mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
+                    mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
+                    mixed_i = mixed_temp
+                    np.rint(mixed_j, out=mixed_j)
+                    np.rint(mixed_i, out=mixed_i)
+            output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
+            output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
         return torch.tensor(lam_batch).unsqueeze(1)

     def _mix_batch_collate(self, output, batch):
@@ -235,19 +292,25 @@ def _mix_batch_collate(self, output, batch):
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                 else:
                     mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    np.round(mixed, out=mixed)
+                    np.rint(mixed, out=mixed)
             output[i] += torch.from_numpy(mixed.astype(np.uint8))
         return lam

     def __call__(self, batch, _=None):
         batch_size = len(batch)
         assert batch_size % 2 == 0, 'Batch size should be even when using this'
+        half = 'half' in self.mode
+        if half:
+            batch_size //= 2
         output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
-        if self.elementwise:
-            lam = self._mix_elem_collate(output, batch)
+        if self.mode == 'elem' or self.mode == 'half':
+            lam = self._mix_elem_collate(output, batch, half=half)
+        elif self.mode == 'pair':
+            lam = self._mix_pair_collate(output, batch)
         else:
             lam = self._mix_batch_collate(output, batch)
         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
         target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+        target = target[:batch_size]
         return output, target

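FastCollateMixup is meant to be used as a DataLoader collate_fn, and the new 'half' mode emits a mixed batch of half the input size: each output element blends a mirrored pair, and the soft-target tensor is truncated to match via target[:batch_size]. A hedged sketch, assuming each dataset item is a (CHW uint8 numpy image, int label) pair as the collate methods above expect:

import numpy as np
import torch
from timm.data.mixup import FastCollateMixup

collate_fn = FastCollateMixup(
    mixup_alpha=0.2, cutmix_alpha=1.0, mode='half', label_smoothing=0.1, num_classes=10)

# fake batch of 8 samples; a real setup would pass collate_fn to torch.utils.data.DataLoader
batch = [(np.random.randint(0, 255, (3, 32, 32), dtype=np.uint8), i % 10) for i in range(8)]
images, soft_targets = collate_fn(batch)

print(images.shape)        # torch.Size([4, 3, 32, 32]) -- half of the 8 input samples
print(soft_targets.shape)  # torch.Size([4, 10])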
train.py

Lines changed: 3 additions & 3 deletions
@@ -176,8 +176,8 @@
                     help='Probability of performing mixup or cutmix when either/both is enabled')
 parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                     help='Probability of switching to cutmix when both mixup and cutmix enabled')
-parser.add_argument('--mixup-elem', action='store_true', default=False,
-                    help='Apply mixup/cutmix params uniquely per batch element instead of per batch.')
+parser.add_argument('--mixup-mode', type=str, default='batch',
+                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
@@ -444,7 +444,7 @@ def main():
     if mixup_active:
         mixup_args = dict(
             mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
-            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, elementwise=args.mixup_elem,
+            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
             label_smoothing=args.smoothing, num_classes=args.num_classes)
         if args.prefetcher:
             assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
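On the CLI side, the old --mixup-elem flag is gone and --mixup-mode takes its place. An illustrative invocation, assuming the usual timm train.py entry point (the data path, model name, and alpha values are placeholders, not part of this commit):

python train.py /path/to/dataset --model resnet50 --mixup 0.2 --cutmix 1.0 --mixup-mode pair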
