Skip to content

Commit 670c61b

Browse files
committed
Some cutmix/mixup cleanup/fixes
1 parent b3cb5f3 commit 670c61b

File tree

1 file changed

+32
-23
lines changed

1 file changed

+32
-23
lines changed

timm/data/mixup.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import numpy as np
1515
import torch
1616
import math
17+
import numbers
1718
from enum import IntEnum
1819

1920

@@ -49,24 +50,33 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab
4950
return input, target
5051

5152

53+
def calc_ratio(lam, minmax=None):
54+
ratio = math.sqrt(1 - lam)
55+
if minmax is not None:
56+
if isinstance(minmax, numbers.Number):
57+
minmax = (minmax, 1 - minmax)
58+
ratio = np.clip(ratio, minmax[0], minmax[1])
59+
return ratio
60+
61+
5262
def rand_bbox(size, ratio):
5363
H, W = size[-2:]
54-
ratio = max(min(ratio, 0.8), 0.2)
5564
cut_h, cut_w = int(H * ratio), int(W * ratio)
5665
cy, cx = np.random.randint(H), np.random.randint(W)
5766
yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
5867
xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
5968
return yl, yh, xl, xh
6069

6170

62-
def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
71+
def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
6372
lam = 1.
6473
if not disable:
6574
lam = np.random.beta(alpha, alpha)
6675
if lam != 1:
67-
ratio = math.sqrt(1. - lam)
68-
yl, yh, xl, xh = rand_bbox(input.size(), ratio)
76+
yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam))
6977
input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
78+
if correct_lam:
79+
lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
7080
target = mixup_target(target, num_classes, lam, smoothing)
7181
return input, target
7282

@@ -82,9 +92,9 @@ def mix_batch(
8292
input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
8393
mode = _resolve_mode(mode)
8494
if mode == MixupMode.CUTMIX:
85-
return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
86-
else:
8795
return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
96+
else:
97+
return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
8898

8999

90100
class FastCollateMixup:
@@ -99,6 +109,7 @@ def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=M
99109
self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
100110
self.mixup_enabled = True
101111
self.correct_lam = False # correct lambda based on clipped area for cutmix
112+
self.ratio_minmax = None # (0.2, 0.8)
102113

103114
def _do_mix(self, tensor, batch):
104115
batch_size = len(batch)
@@ -111,7 +122,7 @@ def _do_mix(self, tensor, batch):
111122

112123
if _resolve_mode(self.mode) == MixupMode.CUTMIX:
113124
mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32)
114-
ratio = math.sqrt(1. - lam)
125+
ratio = calc_ratio(lam, self.ratio_minmax)
115126
if lam != 1:
116127
yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
117128
mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
@@ -132,15 +143,15 @@ def _do_mix(self, tensor, batch):
132143
np.round(mixed_j, out=mixed_j)
133144
tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8))
134145
tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8))
135-
return lam_out
146+
return lam_out.unsqueeze(1)
136147

137148
def __call__(self, batch):
138149
batch_size = len(batch)
139150
assert batch_size % 2 == 0, 'Batch size should be even when using this'
140151
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
141152
lam = self._do_mix(tensor, batch)
142153
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
143-
target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu')
154+
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
144155

145156
return tensor, target
146157

@@ -157,27 +168,27 @@ def _do_mix(self, tensor, batch):
157168
batch_size = len(batch)
158169
lam_out = torch.ones(batch_size)
159170
for i in range(batch_size):
171+
j = batch_size - i - 1
160172
lam = 1.
161173
if self.mixup_enabled:
162174
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
163175

164176
if _resolve_mode(self.mode) == MixupMode.CUTMIX:
165177
mixed = batch[i][0].astype(np.float32)
166-
ratio = math.sqrt(1. - lam)
167178
if lam != 1:
179+
ratio = calc_ratio(lam)
168180
yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
169-
mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
181+
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
170182
if self.correct_lam:
171183
lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
172184
else:
173185
lam_out[i] = lam
174186
else:
175-
mixed = batch[i][0].astype(np.float32) * lam + \
176-
batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
187+
mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
177188
lam_out[i] = lam
178189
np.round(mixed, out=mixed)
179190
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
180-
return lam_out
191+
return lam_out.unsqueeze(1)
181192

182193

183194
class FastCollateMixupBatchwise(FastCollateMixup):
@@ -191,25 +202,23 @@ def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=M
191202

192203
def _do_mix(self, tensor, batch):
193204
batch_size = len(batch)
194-
lam_out = torch.ones(batch_size)
195205
lam = 1.
196206
cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
197207
if self.mixup_enabled:
198208
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
199-
if cutmix and self.correct_lam:
200-
ratio = math.sqrt(1. - lam)
201-
yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio)
202-
lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
209+
if cutmix:
210+
yl, yh, xl, xh = rand_bbox(batch[0][0].shape, calc_ratio(lam))
211+
if self.correct_lam:
212+
lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
203213

204214
for i in range(batch_size):
215+
j = batch_size - i - 1
205216
if cutmix:
206217
mixed = batch[i][0].astype(np.float32)
207218
if lam != 1:
208-
mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
209-
lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
219+
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
210220
else:
211-
mixed = batch[i][0].astype(np.float32) * lam + \
212-
batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
221+
mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
213222
np.round(mixed, out=mixed)
214223
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
215224
return lam

0 commit comments

Comments
 (0)