@@ -229,29 +229,41 @@ def _mix_elem_collate(self, output, batch, half=False):
         num_elem = batch_size // 2 if half else batch_size
         assert len(output) == num_elem
         lam_batch, use_cutmix = self._params_per_elem(num_elem)
+        is_np = isinstance(batch[0][0], np.ndarray)
+
         for i in range(num_elem):
             j = batch_size - i - 1
             lam = lam_batch[i]
             mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix[i]:
                     if not half:
-                        mixed = mixed.copy()
+                        mixed = mixed.copy() if is_np else mixed.clone()
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                        output.shape,
+                        lam,
+                        ratio_minmax=self.cutmix_minmax,
+                        correct_lam=self.correct_lam,
+                    )
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                     lam_batch[i] = lam
                 else:
-                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    np.rint(mixed, out=mixed)
-            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+                    if is_np:
+                        mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                        np.rint(mixed, out=mixed)
+                    else:
+                        mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
+                        torch.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
         if half:
             lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
         return torch.tensor(lam_batch).unsqueeze(1)

     def _mix_pair_collate(self, output, batch):
         batch_size = len(batch)
         lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        is_np = isinstance(batch[0][0], np.ndarray)
+
         for i in range(batch_size // 2):
             j = batch_size - i - 1
             lam = lam_batch[i]
@@ -261,39 +273,60 @@ def _mix_pair_collate(self, output, batch):
             if lam < 1.:
                 if use_cutmix[i]:
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
-                    patch_i = mixed_i[:, yl:yh, xl:xh].copy()
+                        output.shape,
+                        lam,
+                        ratio_minmax=self.cutmix_minmax,
+                        correct_lam=self.correct_lam,
+                    )
+                    patch_i = mixed_i[:, yl:yh, xl:xh].copy() if is_np else mixed_i[:, yl:yh, xl:xh].clone()
                     mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
                     mixed_j[:, yl:yh, xl:xh] = patch_i
                     lam_batch[i] = lam
                 else:
-                    mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
-                    mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
-                    mixed_i = mixed_temp
-                    np.rint(mixed_j, out=mixed_j)
-                    np.rint(mixed_i, out=mixed_i)
-            output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
-            output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+                    if is_np:
+                        mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
+                        mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
+                        mixed_i = mixed_temp
+                        np.rint(mixed_j, out=mixed_j)
+                        np.rint(mixed_i, out=mixed_i)
+                    else:
+                        mixed_temp = mixed_i.float() * lam + mixed_j.float() * (1 - lam)
+                        mixed_j = mixed_j.float() * lam + mixed_i.float() * (1 - lam)
+                        mixed_i = mixed_temp
+                        torch.round(mixed_j, out=mixed_j)
+                        torch.round(mixed_i, out=mixed_i)
+            output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) if is_np else mixed_i.byte()
+            output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) if is_np else mixed_j.byte()
         lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
         return torch.tensor(lam_batch).unsqueeze(1)

     def _mix_batch_collate(self, output, batch):
         batch_size = len(batch)
         lam, use_cutmix = self._params_per_batch()
+        is_np = isinstance(batch[0][0], np.ndarray)
+
         if use_cutmix:
             (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                output.shape,
+                lam,
+                ratio_minmax=self.cutmix_minmax,
+                correct_lam=self.correct_lam,
+            )
         for i in range(batch_size):
             j = batch_size - i - 1
             mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix:
-                    mixed = mixed.copy()  # don't want to modify the original while iterating
+                    mixed = mixed.copy() if is_np else mixed.clone()  # don't want to modify the original while iterating
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                 else:
-                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    np.rint(mixed, out=mixed)
-            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+                    if is_np:
+                        mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                        np.rint(mixed, out=mixed)
+                    else:
+                        mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
+                        torch.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
         return lam

     def __call__(self, batch, _=None):
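
Taken together, the three collate paths now accept batch items that are either `np.ndarray` or `torch.Tensor`, switching between `copy()`/`astype()`/`np.rint` and `clone()`/`float()`/`torch.round` via a single `is_np` check. Below is a minimal, self-contained sketch of that branching for one pair of images; the helper name `mix_uint8_pair` and the shapes are illustrative only and are not part of the patch:

```python
import numpy as np
import torch


def mix_uint8_pair(img_a, img_b, lam):
    """Blend two uint8 CHW images by lam and return a uint8 torch.Tensor.

    Both inputs are expected to be the same type, either np.ndarray or
    torch.Tensor, mirroring the is_np branch introduced in the diff.
    """
    is_np = isinstance(img_a, np.ndarray)
    if is_np:
        # numpy path: blend in float32, round in place, cast back to uint8
        mixed = img_a.astype(np.float32) * lam + img_b.astype(np.float32) * (1 - lam)
        np.rint(mixed, out=mixed)
        return torch.from_numpy(mixed.astype(np.uint8))
    # tensor path: same arithmetic with torch ops, then cast to uint8
    mixed = img_a.float() * lam + img_b.float() * (1 - lam)
    torch.round(mixed, out=mixed)
    return mixed.byte()


if __name__ == "__main__":
    a = np.random.randint(0, 256, (3, 8, 8), dtype=np.uint8)
    b = np.random.randint(0, 256, (3, 8, 8), dtype=np.uint8)
    out_np = mix_uint8_pair(a, b, lam=0.7)
    out_pt = mix_uint8_pair(torch.from_numpy(a), torch.from_numpy(b), lam=0.7)
    print(bool((out_np == out_pt).all()))  # the two paths should agree
```

Since `np.rint` and `torch.round` both round halfway cases to the nearest even integer, the numpy and tensor branches should produce identical uint8 output for the same inputs.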