import torch
import math
import numbers
- from enum import IntEnum
-
-
- class MixupMode(IntEnum):
-     MIXUP = 0
-     CUTMIX = 1
-     RANDOM = 2
-
-     @classmethod
-     def from_str(cls, value):
-         return cls[value.upper()]


def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
@@ -50,132 +39,185 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab
    return input, target


- def calc_ratio(lam, minmax=None):
+ def rand_bbox(size, lam, border=0., count=None):
    ratio = math.sqrt(1 - lam)
-     if minmax is not None:
-         if isinstance(minmax, numbers.Number):
-             minmax = (minmax, 1 - minmax)
-         ratio = np.clip(ratio, minmax[0], minmax[1])
-     return ratio
-
-
- def rand_bbox(size, ratio):
-     H, W = size[-2:]
-     cut_h, cut_w = int(H * ratio), int(W * ratio)
-     cy, cx = np.random.randint(H), np.random.randint(W)
-     yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
-     xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
+     img_h, img_w = size[-2:]
+     cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+     margin_y, margin_x = int(border * cut_h), int(border * cut_w)
+     cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+     cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+     yl = np.clip(cy - cut_h // 2, 0, img_h)
+     yh = np.clip(cy + cut_h // 2, 0, img_h)
+     xl = np.clip(cx - cut_w // 2, 0, img_w)
+     xh = np.clip(cx + cut_w // 2, 0, img_w)
    return yl, yh, xl, xh


+ def rand_bbox_minmax(size, minmax, count=None):
+     assert len(minmax) == 2
+     img_h, img_w = size[-2:]
+     cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
+     cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
+     yl = np.random.randint(0, img_h - cut_h, size=count)
+     xl = np.random.randint(0, img_w - cut_w, size=count)
+     yu = yl + cut_h
+     xu = xl + cut_w
+     return yl, yu, xl, xu
+
+
+ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+     if ratio_minmax is not None:
+         yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
+     else:
+         yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
+     if correct_lam or ratio_minmax is not None:
+         bbox_area = (yu - yl) * (xu - xl)
+         lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1])
+     return (yl, yu, xl, xu), lam
+
+
def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
    lam = 1.
    if not disable:
        lam = np.random.beta(alpha, alpha)
    if lam != 1:
-         yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam))
+         yl, yh, xl, xh = rand_bbox(input.size(), lam)
        input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
        if correct_lam:
            lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
    target = mixup_target(target, num_classes, lam, smoothing)
    return input, target


- def _resolve_mode(mode):
-     mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
-     if mode == MixupMode.RANDOM:
-         mode = MixupMode(np.random.rand() > 0.7)
-     return mode  # will be one of cutmix or mixup
-
-
def mix_batch(
-         input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
-     mode = _resolve_mode(mode)
-     if mode == MixupMode.CUTMIX:
-         return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
+         input, target, mixup_alpha=0.2, cutmix_alpha=0., prob=1.0, switch_prob=.5,
+         num_classes=1000, smoothing=0.1, disable=False):
+     # FIXME test this version
+     if np.random.rand() > prob:
+         return input, target
+     use_cutmix = cutmix_alpha > 0. and np.random.rand() <= switch_prob
+     if use_cutmix:
+         return cutmix_batch(input, target, cutmix_alpha, num_classes, smoothing, disable)
    else:
-         return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
+         return mixup_batch(input, target, mixup_alpha, num_classes, smoothing, disable)


class FastCollateMixup:
-     """Fast Collate Mixup that applies different params to each element + flipped pair
+     """Fast Collate Mixup/Cutmix that applies different params to each element or whole batch

    NOTE once experiments are done, one of the three variants will remain with this class name
+
    """
-     def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
+     def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
+                  elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
+         """
+
+         Args:
+             mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+             cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+             cutmix_minmax (float): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None
+             prob (float): probability of applying mixup or cutmix per batch or element
+             switch_prob (float): probability of using cutmix instead of mixup when both active
+             elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+             label_smoothing (float):
+             num_classes (int):
+         """
        self.mixup_alpha = mixup_alpha
+         self.cutmix_alpha = cutmix_alpha
+         self.cutmix_minmax = cutmix_minmax
+         if self.cutmix_minmax is not None:
+             assert len(self.cutmix_minmax) == 2
+             # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
+             self.cutmix_alpha = 1.0
+         self.prob = prob
+         self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
-         self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
-         self.mixup_enabled = True
-         self.correct_lam = True  # correct lambda based on clipped area for cutmix
-         self.ratio_minmax = None  # (0.2, 0.8)
+         self.elementwise = elementwise
+         self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
+         self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)

-     def _do_mix(self, tensor, batch):
+     def _mix_elem(self, output, batch):
        batch_size = len(batch)
-         lam_out = torch.ones(batch_size)
+         lam_out = np.ones(batch_size)
+         use_cutmix = np.zeros(batch_size).astype(np.bool)
+         if self.mixup_enabled:
+             if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                 use_cutmix = np.random.rand(batch_size) < self.switch_prob
+                 lam_mix = np.where(
+                     use_cutmix,
+                     np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+                     np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+             elif self.mixup_alpha > 0.:
+                 lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+             elif self.cutmix_alpha > 0.:
+                 use_cutmix = np.ones(batch_size).astype(np.bool)
+                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+             else:
+                 assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+             lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out)
+
        for i in range(batch_size):
            j = batch_size - i - 1
-             lam = 1.
-             if self.mixup_enabled:
-                 lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
-
-             if _resolve_mode(self.mode) == MixupMode.CUTMIX:
-                 mixed = batch[i][0].astype(np.float32)
-                 if lam != 1:
-                     ratio = calc_ratio(lam)
-                     yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
+             lam = lam_out[i]
+             mixed = batch[i][0].astype(np.float32)
+             if lam != 1.:
+                 if use_cutmix[i]:
+                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                         output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
-                     if self.correct_lam:
-                         lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
-                     else:
-                         lam_out[i] = lam
+                     lam_out[i] = lam
+                 else:
+                     mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                     lam_out[i] = lam
+             np.round(mixed, out=mixed)
+             output[i] += torch.from_numpy(mixed.astype(np.uint8))
+         return torch.tensor(lam_out).unsqueeze(1)
+
+     def _mix_batch(self, output, batch):
+         batch_size = len(batch)
+         lam = 1.
+         use_cutmix = False
+         if self.mixup_enabled and np.random.rand() < self.prob:
+             if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                 use_cutmix = np.random.rand() < self.switch_prob
+                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+                     np.random.beta(self.mixup_alpha, self.mixup_alpha)
+             elif self.mixup_alpha > 0.:
+                 lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+             elif self.cutmix_alpha > 0.:
+                 use_cutmix = True
+                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
            else:
-                 mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                 lam_out[i] = lam
-             np.round(mixed, out=mixed)
-             tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
-         return lam_out.unsqueeze(1)
+                 assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+             lam = lam_mix
+
+         if use_cutmix:
+             (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                 output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+
+         for i in range(batch_size):
+             j = batch_size - i - 1
+             mixed = batch[i][0].astype(np.float32)
+             if lam != 1.:
+                 if use_cutmix:
+                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
+                 else:
+                     mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
+             np.round(mixed, out=mixed)
+             output[i] += torch.from_numpy(mixed.astype(np.uint8))
+         return lam

    def __call__(self, batch):
        batch_size = len(batch)
        assert batch_size % 2 == 0, 'Batch size should be even when using this'
-         tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
-         lam = self._do_mix(tensor, batch)
+         output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+         if self.elementwise:
+             lam = self._mix_elem(output, batch)
+         else:
+             lam = self._mix_batch(output, batch)
        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')

-         return tensor, target
-
-
- class FastCollateMixupBatchwise(FastCollateMixup):
-     """Fast Collate Mixup that applies same params to whole batch
-
-     NOTE this is for experimentation, may remove at some point
-     """
-
-     def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
-         super(FastCollateMixupBatchwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode)
+         return output, target

-     def _do_mix(self, tensor, batch):
-         batch_size = len(batch)
-         lam = 1.
-         cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
-         if self.mixup_enabled:
-             lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
-             if cutmix:
-                 yl, yh, xl, xh = rand_bbox(batch[0][0].shape, calc_ratio(lam))
-                 if self.correct_lam:
-                     lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
-
-         for i in range(batch_size):
-             j = batch_size - i - 1
-             if cutmix:
-                 mixed = batch[i][0].astype(np.float32)
-                 if lam != 1:
-                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
-             else:
-                 mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-             np.round(mixed, out=mixed)
-             tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
-         return lam
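
Below the patch, a minimal sketch of how the newly added cutmix_bbox_and_lam helper behaves: with correct_lam=True the returned lam is recomputed from the (possibly edge-clipped) box so the label weights match the pixels actually pasted. The snippet assumes it runs with the functions from this file in scope (for example pasted at the bottom of the module); the image shape and beta parameters are arbitrary choices, not from the patch.

import numpy as np

# draw a raw mixing factor, then let the helper pick a box and correct lam
img_shape = (3, 224, 224)
lam_raw = float(np.random.beta(1.0, 1.0))
(yl, yu, xl, xu), lam = cutmix_bbox_and_lam(img_shape, lam_raw, correct_lam=True)

# the corrected lam equals the fraction of pixels kept from the original image
kept = 1. - (yu - yl) * (xu - xl) / (img_shape[-2] * img_shape[-1])
assert abs(lam - kept) < 1e-6
print(f'raw lam {lam_raw:.3f} -> corrected lam {lam:.3f}, box y {yl}:{yu} x {xl}:{xu}')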
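A second sketch, for the collate path: wiring the reworked FastCollateMixup into a DataLoader. The toy dataset below is an assumption made for illustration; the only contract taken from the code above is that each batch item is a (uint8 CHW numpy image, int label) pair and that the batch size is even.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    # yields (uint8 CHW image, int label) pairs, the format __call__ above indexes
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        img = np.random.randint(0, 256, (3, 224, 224), dtype=np.uint8)
        return img, idx % 10


collate = FastCollateMixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
    elementwise=False, label_smoothing=0.1, num_classes=10)

loader = DataLoader(ToyDataset(), batch_size=8, collate_fn=collate)
inputs, targets = next(iter(loader))
# inputs: uint8 [8, 3, 224, 224] with mixup or cutmix already applied
# targets: [8, 10] soft labels produced by mixup_target with smoothing and lam weighting
print(inputs.shape, inputs.dtype, targets.shape)

With elementwise=True each sample draws its own lam (and its own mixup vs cutmix choice), and _mix_elem returns a per-sample lam column instead of the single scalar used for the whole batch.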