@@ -14,6 +14,7 @@
 import numpy as np
 import torch
 import math
+import numbers
 from enum import IntEnum
@@ -49,24 +50,33 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab
     return input, target


+def calc_ratio(lam, minmax=None):
+    ratio = math.sqrt(1 - lam)
+    if minmax is not None:
+        if isinstance(minmax, numbers.Number):
+            minmax = (minmax, 1 - minmax)
+        ratio = np.clip(ratio, minmax[0], minmax[1])
+    return ratio
+
+
 def rand_bbox(size, ratio):
     H, W = size[-2:]
-    ratio = max(min(ratio, 0.8), 0.2)
     cut_h, cut_w = int(H * ratio), int(W * ratio)
     cy, cx = np.random.randint(H), np.random.randint(W)
     yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
     xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
     return yl, yh, xl, xh


-def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
+def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
     lam = 1.
     if not disable:
         lam = np.random.beta(alpha, alpha)
     if lam != 1:
-        ratio = math.sqrt(1. - lam)
-        yl, yh, xl, xh = rand_bbox(input.size(), ratio)
+        yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam))
         input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
+        if correct_lam:
+            lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
     target = mixup_target(target, num_classes, lam, smoothing)
     return input, target
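Reviewer note: a quick, self-contained sanity check (not part of the diff) of what the new calc_ratio helper does. The cut edge ratio is sqrt(1 - lam), so the cut-out area is roughly a (1 - lam) fraction of the image, and a scalar minmax such as 0.3 expands to the symmetric clip range (0.3, 0.7):

    import math
    import numbers
    import numpy as np

    def calc_ratio(lam, minmax=None):  # mirrors the helper added above
        ratio = math.sqrt(1 - lam)
        if minmax is not None:
            if isinstance(minmax, numbers.Number):
                minmax = (minmax, 1 - minmax)  # scalar -> symmetric range
            ratio = np.clip(ratio, minmax[0], minmax[1])
        return ratio

    print(calc_ratio(0.99))       # 0.1 -> cut covers ~1% of the image area
    print(calc_ratio(0.99, 0.3))  # clipped up to 0.3 by the (0.3, 0.7) range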
@@ -82,9 +92,9 @@ def mix_batch(
         input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
     mode = _resolve_mode(mode)
     if mode == MixupMode.CUTMIX:
-        return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
-    else:
         return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
+    else:
+        return mixup_batch(input, target, alpha, num_classes, smoothing, disable)


 class FastCollateMixup:
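The swap above is the substantive fix in mix_batch: CUTMIX mode previously dispatched to mixup_batch and vice versa. A hedged usage sketch, with the import path and tensor shapes assumed rather than taken from the diff:

    import torch
    # module path is an assumption; adjust to wherever this file lives
    from mixup import mix_batch, MixupMode

    input = torch.randn(8, 3, 224, 224)    # (B, C, H, W) image batch
    target = torch.randint(0, 1000, (8,))  # (B,) integer class labels
    input, target = mix_batch(input, target, alpha=0.2, num_classes=1000,
                              smoothing=0.1, mode=MixupMode.CUTMIX)
    # mode=MixupMode.CUTMIX now correctly routes to cutmix_batch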
@@ -99,6 +109,7 @@ def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=M
         self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
         self.mixup_enabled = True
         self.correct_lam = False  # correct lambda based on clipped area for cutmix
+        self.ratio_minmax = None  # (0.2, 0.8)

     def _do_mix(self, tensor, batch):
         batch_size = len(batch)
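ratio_minmax defaults to None, so nothing is clipped out of the box; the commented (0.2, 0.8) hints at the hard-coded clamp that was removed from rand_bbox. Restoring it is a two-line opt-in, sketched here under the assumption that MixupMode.from_str accepts the string 'cutmix':

    collate_fn = FastCollateMixup(mixup_alpha=1., label_smoothing=0.1,
                                  num_classes=1000, mode='cutmix')
    collate_fn.ratio_minmax = (0.2, 0.8)  # reinstate the old hard-coded clamp
    collate_fn.correct_lam = True         # rescale lam by the actual cut area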
@@ -111,7 +122,7 @@ def _do_mix(self, tensor, batch):

             if _resolve_mode(self.mode) == MixupMode.CUTMIX:
                 mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32)
-                ratio = math.sqrt(1. - lam)
+                ratio = calc_ratio(lam, self.ratio_minmax)
                 if lam != 1:
                     yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
                     mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
@@ -132,15 +143,15 @@ def _do_mix(self, tensor, batch):
             np.round(mixed_j, out=mixed_j)
             tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8))
             tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8))
-        return lam_out
+        return lam_out.unsqueeze(1)

     def __call__(self, batch):
         batch_size = len(batch)
         assert batch_size % 2 == 0, 'Batch size should be even when using this'
         tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
         lam = self._do_mix(tensor, batch)
         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-        target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu')
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')

         return tensor, target
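Moving unsqueeze(1) into _do_mix keeps __call__ agnostic to whether lam is a scalar or per-sample. The shape matters because per-sample lambdas must broadcast against one-hot targets of shape (batch, num_classes) inside mixup_target; a small standalone illustration of that broadcasting (not the module's code):

    import torch

    batch_size, num_classes = 4, 10
    lam = torch.rand(batch_size).unsqueeze(1)              # (4, 1) per-sample weights
    y1 = torch.eye(num_classes)[torch.arange(batch_size)]  # (4, 10) one-hot labels
    y2 = y1.flip(0)                                        # labels of the mixed-in partners
    target = y1 * lam + y2 * (1. - lam)                    # broadcasts to (4, 10)
    print(target.shape)                                    # torch.Size([4, 10])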
@@ -157,27 +168,27 @@ def _do_mix(self, tensor, batch):
         batch_size = len(batch)
         lam_out = torch.ones(batch_size)
         for i in range(batch_size):
+            j = batch_size - i - 1
             lam = 1.
             if self.mixup_enabled:
                 lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)

             if _resolve_mode(self.mode) == MixupMode.CUTMIX:
                 mixed = batch[i][0].astype(np.float32)
-                ratio = math.sqrt(1. - lam)
                 if lam != 1:
+                    ratio = calc_ratio(lam)
                     yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
-                    mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
                     if self.correct_lam:
                         lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
                     else:
                         lam_out[i] = lam
             else:
-                mixed = batch[i][0].astype(np.float32) * lam + \
-                    batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+                mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
                 lam_out[i] = lam
             np.round(mixed, out=mixed)
             tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
-        return lam_out
+        return lam_out.unsqueeze(1)


 class FastCollateMixupBatchwise(FastCollateMixup):
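Hoisting j = batch_size - i - 1 names the pairing once instead of repeating the index arithmetic: each sample is mixed with its mirror across the batch, so every pair is visited twice, once from each side, with an independently sampled lam. A one-liner showing the pairing:

    batch_size = 6
    print([(i, batch_size - i - 1) for i in range(batch_size)])
    # [(0, 5), (1, 4), (2, 3), (3, 2), (4, 1), (5, 0)]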
@@ -191,25 +202,23 @@ def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=M

     def _do_mix(self, tensor, batch):
         batch_size = len(batch)
-        lam_out = torch.ones(batch_size)
         lam = 1.
         cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
         if self.mixup_enabled:
             lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
-            if cutmix and self.correct_lam:
-                ratio = math.sqrt(1. - lam)
-                yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio)
-                lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
+            if cutmix:
+                yl, yh, xl, xh = rand_bbox(batch[0][0].shape, calc_ratio(lam))
+                if self.correct_lam:
+                    lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])

         for i in range(batch_size):
+            j = batch_size - i - 1
             if cutmix:
                 mixed = batch[i][0].astype(np.float32)
                 if lam != 1:
-                    mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
-                    lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
             else:
-                mixed = batch[i][0].astype(np.float32) * lam + \
-                    batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+                mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
             np.round(mixed, out=mixed)
             tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
         return lam
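Because the batchwise variant shares one lam and one bbox across the whole batch, the per-sample lam_out tensor could be dropped and lam corrected once, up front. A minimal standalone sketch of that corrected-lam computation, assuming a 224x224 input (the sizes are illustrative, not from the diff):

    import math
    import numpy as np

    H = W = 224
    lam = np.random.beta(1., 1.)
    ratio = math.sqrt(1 - lam)  # same lam-to-ratio mapping as calc_ratio
    cut_h, cut_w = int(H * ratio), int(W * ratio)
    cy, cx = np.random.randint(H), np.random.randint(W)
    yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
    xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
    # corrected lam = fraction of pixels NOT replaced by the cut,
    # which can differ from the sampled lam when the box is clipped at a border
    lam = 1 - (yh - yl) * (xh - xl) / (H * W)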