@@ -96,13 +96,13 @@ class Mixup:
         cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
         prob (float): probability of applying mixup or cutmix per batch or element
         switch_prob (float): probability of switching to cutmix instead of mixup when both are active
-        elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
         correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
         label_smoothing (float): apply label smoothing to the mixed target tensor
         num_classes (int): number of classes for target
     """
     def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
-                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
+                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
         self.mixup_alpha = mixup_alpha
         self.cutmix_alpha = cutmix_alpha
         self.cutmix_minmax = cutmix_minmax
@@ -114,7 +114,7 @@ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0
         self.switch_prob = switch_prob
         self.label_smoothing = label_smoothing
         self.num_classes = num_classes
-        self.elementwise = elementwise
+        self.mode = mode
         self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
         self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)
 
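
For reference, the new `mode` argument replaces the old `elementwise` flag at construction time. The snippet below is a usage sketch only, not part of this diff; the alpha/prob values are illustrative and it assumes the module stays importable as `timm.data.mixup`:

    # usage sketch only (not part of this diff); alpha/prob values are illustrative
    from timm.data.mixup import Mixup

    mixup_fn = Mixup(
        mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
        mode='pair',  # 'batch' (default), 'pair', or 'elem'
        label_smoothing=0.1, num_classes=1000)
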
@@ -173,6 +173,26 @@ def _mix_elem(self, x):
                 x[i] = x[i] * lam + x_orig[j] * (1 - lam)
         return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
 
+    def _mix_pair(self, x):
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        x_orig = x.clone()  # need to keep an unmodified original for mixing source
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+                    x[j] = x[j] * lam + x_orig[i] * (1 - lam)
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
     def _mix_batch(self, x):
         lam, use_cutmix = self._params_per_batch()
         if lam == 1.:
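
The pairing scheme above mixes element i with its mirror element batch_size - i - 1, applies the same lambda to both halves of the pair, then mirrors the per-pair lambdas so each element's row in the returned tensor matches its own mixed sample. A minimal sketch of just that index/lambda bookkeeping (hypothetical helper, not in the diff):

    import numpy as np

    def pair_indices_and_lams(batch_size, lam_half):
        # lam_half holds one lambda per pair (length batch_size // 2)
        pairs = [(i, batch_size - i - 1) for i in range(batch_size // 2)]
        lam_full = np.concatenate((lam_half, lam_half[::-1]))  # partner j reuses the lam of i
        return pairs, lam_full

    pairs, lam_full = pair_indices_and_lams(8, np.array([0.9, 0.7, 0.6, 0.55]))
    # pairs -> [(0, 7), (1, 6), (2, 5), (3, 4)]; lam_full[0] == lam_full[7] == 0.9
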
@@ -188,7 +208,12 @@ def _mix_batch(self, x):
 
     def __call__(self, x, target):
         assert len(x) % 2 == 0, 'Batch size should be even when using this'
-        lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
+        if self.mode == 'elem':
+            lam = self._mix_elem(x)
+        elif self.mode == 'pair':
+            lam = self._mix_pair(x)
+        else:
+            lam = self._mix_batch(x)
         target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
         return x, target
 
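
With this dispatch, a training step looks the same regardless of mode. The sketch below is illustrative only (model, loader, optimizer and soft_loss_fn are placeholders, not part of the diff); it assumes a loss that accepts the soft targets produced by mixup_target, such as timm's SoftTargetCrossEntropy:

    # illustrative training step; model / loader / optimizer / soft_loss_fn are placeholders
    for images, labels in loader:                         # labels: [B] int64 class ids
        images, soft_targets = mixup_fn(images, labels)   # soft_targets: [B, num_classes]
        loss = soft_loss_fn(model(images), soft_targets)  # e.g. timm's SoftTargetCrossEntropy
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
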
@@ -199,25 +224,57 @@ class FastCollateMixup(Mixup):
     A Mixup impl that's performed while collating the batches.
     """
 
-    def _mix_elem_collate(self, output, batch):
+    def _mix_elem_collate(self, output, batch, half=False):
         batch_size = len(batch)
-        lam_batch, use_cutmix = self._params_per_elem(batch_size)
-        for i in range(batch_size):
+        num_elem = batch_size // 2 if half else batch_size
+        assert len(output) == num_elem
+        lam_batch, use_cutmix = self._params_per_elem(num_elem)
+        for i in range(num_elem):
             j = batch_size - i - 1
             lam = lam_batch[i]
             mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix[i]:
-                    mixed = mixed.copy()
+                    if not half:
+                        mixed = mixed.copy()
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                         output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                     lam_batch[i] = lam
                 else:
                     mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    lam_batch[i] = lam
-            np.round(mixed, out=mixed)
+                    np.rint(mixed, out=mixed)
             output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        if half:
+            lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
+        return torch.tensor(lam_batch).unsqueeze(1)
+
+    def _mix_pair_collate(self, output, batch):
+        batch_size = len(batch)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            mixed_i = batch[i][0]
+            mixed_j = batch[j][0]
+            assert 0 <= lam <= 1.0
+            if lam < 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    patch_i = mixed_i[:, yl:yh, xl:xh].copy()
+                    mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
+                    mixed_j[:, yl:yh, xl:xh] = patch_i
+                    lam_batch[i] = lam
+                else:
+                    mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
+                    mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
+                    mixed_i = mixed_temp
+                    np.rint(mixed_j, out=mixed_j)
+                    np.rint(mixed_i, out=mixed_i)
+            output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
+            output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
         return torch.tensor(lam_batch).unsqueeze(1)
 
     def _mix_batch_collate(self, output, batch):
@@ -235,19 +292,25 @@ def _mix_batch_collate(self, output, batch):
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                 else:
                     mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-            np.round(mixed, out=mixed)
+                    np.rint(mixed, out=mixed)
             output[i] += torch.from_numpy(mixed.astype(np.uint8))
         return lam
 
     def __call__(self, batch, _=None):
         batch_size = len(batch)
         assert batch_size % 2 == 0, 'Batch size should be even when using this'
+        half = 'half' in self.mode
+        if half:
+            batch_size //= 2
         output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
-        if self.elementwise:
-            lam = self._mix_elem_collate(output, batch)
+        if self.mode == 'elem' or self.mode == 'half':
+            lam = self._mix_elem_collate(output, batch, half=half)
+        elif self.mode == 'pair':
+            lam = self._mix_pair_collate(output, batch)
         else:
             lam = self._mix_batch_collate(output, batch)
         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
         target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+        target = target[:batch_size]
         return output, target
 
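
Since FastCollateMixup performs the mixing while collating, it is passed as the DataLoader's collate_fn; with the new 'half' handling, both the collated image batch and the sliced target end up half the nominal batch size. A hedged sketch, assuming samples arrive as (uint8 CHW numpy array, int label) tuples as in timm's fast-collate pipeline (dataset and values are placeholders, not part of the diff):

    # illustrative only; train_dataset is a placeholder yielding (uint8 CHW numpy array, int label)
    from torch.utils.data import DataLoader
    from timm.data.mixup import FastCollateMixup

    collate = FastCollateMixup(
        mixup_alpha=0.8, cutmix_alpha=1.0, mode='pair',  # or 'batch', 'elem', 'half'
        label_smoothing=0.1, num_classes=1000)
    loader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate)
    # each batch: (uint8 image tensor, soft target tensor); with mode='half' both have 64 rows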