+ """ Mixup and Cutmix
+
+ Papers:
+ mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
+
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
+
+ Code Reference:
+ CutMix: https://github.com/clovaai/CutMix-PyTorch
+
+ Hacked together by Ross Wightman
+ """
+
import numpy as np
import torch
+ import math
+ from enum import IntEnum
+
+
+ class MixupMode(IntEnum):
+     MIXUP = 0
+     CUTMIX = 1
+     RANDOM = 2
+
+     @classmethod
+     def from_str(cls, value):
+         return cls[value.upper()]
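+
+ # Example: MixupMode.from_str('cutmix') is MixupMode.CUTMIX, so config strings
+ # map directly onto the enum; MixupMode.RANDOM defers the per-call mixup vs.
+ # cutmix choice to _resolve_mode() below.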


def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
@@ -12,7 +37,7 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
    on_value = 1. - smoothing + off_value
    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
-     return lam * y1 + (1. - lam) * y2
+     return y1 * lam + y2 * (1. - lam)
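+ # For example, with num_classes=5, smoothing=0. and lam=0.7, labels (2, 4) at
+ # opposite ends of the batch mix to [0, 0, 0.7, 0, 0.3] and [0, 0, 0.3, 0, 0.7].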


def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
@@ -24,28 +49,167 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
    return input, target


+ def rand_bbox(size, ratio):
+     H, W = size[-2:]
+     ratio = max(min(ratio, 0.8), 0.2)
+     cut_h, cut_w = int(H * ratio), int(W * ratio)
+     cy, cx = np.random.randint(H), np.random.randint(W)
+     yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
+     xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
+     return yl, yh, xl, xh
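+ # NOTE the box is clipped at the image borders, so the realized cut area can be
+ # smaller than ratio**2 * H * W, e.g. a ratio=0.5 box on a 224x224 image is at
+ # most 112x112 but shrinks when centered near an edge (hence correct_lam below).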
+
+
+ def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
+     lam = 1.
+     if not disable:
+         lam = np.random.beta(alpha, alpha)
+     if lam != 1:
+         ratio = math.sqrt(1. - lam)
+         yl, yh, xl, xh = rand_bbox(input.size(), ratio)
+         input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
+     target = mixup_target(target, num_classes, lam, smoothing)
+     return input, target
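+ # The cut side ratio sqrt(1 - lam) makes the (unclipped) patch cover a fraction
+ # 1 - lam of the pixels, so the kept area matches the mixup_target weight lam.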
+
+
+ def _resolve_mode(mode):
+     mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
+     if mode == MixupMode.RANDOM:
+         mode = MixupMode(np.random.rand() > 0.5)
+     return mode  # will be one of cutmix or mixup
+
+
+ def mix_batch(
+         input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
+     mode = _resolve_mode(mode)
+     if mode == MixupMode.CUTMIX:
+         return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
+     else:
+         return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
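+ # e.g. mix_batch(input, target, mode='random') applies mixup or cutmix with
+ # probability 0.5 each, mixing every sample with its flipped-batch partner.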
+
+
class FastCollateMixup:
+     """Fast Collate Mixup that applies different params to each element + flipped pair

+     NOTE once experiments are done, one of the three variants will remain with this class name
+     """
-     def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
+     def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
        self.mixup_alpha = mixup_alpha
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
+         self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
        self.mixup_enabled = True
+         self.correct_lam = False  # correct lambda based on clipped area for cutmix
+
+     def _do_mix(self, tensor, batch):
+         batch_size = len(batch)
+         lam_out = torch.ones(batch_size)
+         for i in range(batch_size // 2):
+             j = batch_size - i - 1
+             lam = 1.
+             if self.mixup_enabled:
+                 lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+
+             if _resolve_mode(self.mode) == MixupMode.CUTMIX:
+                 mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32)
+                 ratio = math.sqrt(1. - lam)
+                 if lam != 1:
+                     yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
+                     mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
+                     mixed_j[:, yl:yh, xl:xh] = batch[i][0][:, yl:yh, xl:xh].astype(np.float32)
+                     if self.correct_lam:
+                         lam_corrected = (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
+                         lam_out[i] -= lam_corrected
+                         lam_out[j] -= lam_corrected
+                     else:
+                         lam_out[i] = lam
+                         lam_out[j] = lam
+             else:
+                 mixed_i = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                 mixed_j = batch[j][0].astype(np.float32) * lam + batch[i][0].astype(np.float32) * (1 - lam)
+                 lam_out[i] = lam
+                 lam_out[j] = lam
+             np.round(mixed_i, out=mixed_i)
+             np.round(mixed_j, out=mixed_j)
+             tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8))
+             tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+         return lam_out

    def __call__(self, batch):
        batch_size = len(batch)
+         assert batch_size % 2 == 0, 'Batch size should be even when using this'
+         tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+         lam = self._do_mix(tensor, batch)
+         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+         target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu')
+
+         return tensor, target
+
+
+ class FastCollateMixupElementwise(FastCollateMixup):
+     """Fast Collate Mixup that applies different params to each batch element
+
+     NOTE this is for experimentation, may remove at some point
+     """
+     def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
+         super(FastCollateMixupElementwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode)
+
+     def _do_mix(self, tensor, batch):
+         batch_size = len(batch)
+         lam_out = torch.ones(batch_size)
+         for i in range(batch_size):
+             lam = 1.
+             if self.mixup_enabled:
+                 lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+
+             if _resolve_mode(self.mode) == MixupMode.CUTMIX:
+                 mixed = batch[i][0].astype(np.float32)
+                 ratio = math.sqrt(1. - lam)
+                 if lam != 1:
+                     yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
+                     mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
+                     if self.correct_lam:
+                         lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
+                     else:
+                         lam_out[i] = lam
+             else:
+                 mixed = batch[i][0].astype(np.float32) * lam + \
+                     batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+                 lam_out[i] = lam
+             np.round(mixed, out=mixed)
+             tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
+         return lam_out
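+ # Unlike FastCollateMixup above, each element draws its own lam here, so sample
+ # i and its flipped partner are no longer mixed with complementary weights.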
+
+
+ class FastCollateMixupBatchwise(FastCollateMixup):
+     """Fast Collate Mixup that applies same params to whole batch
+
+     NOTE this is for experimentation, may remove at some point
+     """
+
+     def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
+         super(FastCollateMixupBatchwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode)
+
+     def _do_mix(self, tensor, batch):
+         batch_size = len(batch)
+         lam_out = torch.ones(batch_size)
        lam = 1.
+         cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
        if self.mixup_enabled:
            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+             if cutmix:
+                 # bbox must be sampled whenever cutmix is active, not only when
+                 # correcting lambda, since the mix loop below relies on it
+                 ratio = math.sqrt(1. - lam)
+                 yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio)
+                 if self.correct_lam:
+                     lam = 1. - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])

-         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-         target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
-
-         tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
-             mixed = batch[i][0].astype(np.float32) * lam + \
-                 batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+             if cutmix:
+                 mixed = batch[i][0].astype(np.float32)
+                 if lam != 1:
+                     mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
+             else:
+                 mixed = batch[i][0].astype(np.float32) * lam + \
+                     batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
            np.round(mixed, out=mixed)
            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
-
-         return tensor, target
+         return lam_out * lam  # same lam batchwise, returned per element so __call__ can unsqueeze it
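+
+
+ if __name__ == '__main__':
+     # Usage sketch with illustrative shapes/values: FastCollateMixup is meant to
+     # be passed as collate_fn to a DataLoader yielding (uint8 CHW numpy image,
+     # int label) samples; targets come back as mixed, smoothed one-hot vectors.
+     collate = FastCollateMixup(mixup_alpha=1., label_smoothing=0.1, num_classes=10, mode='random')
+     samples = [(np.random.randint(0, 255, (3, 32, 32), dtype=np.uint8), i % 10) for i in range(8)]
+     images, targets = collate(samples)
+     print(images.shape, images.dtype)  # torch.Size([8, 3, 32, 32]) torch.uint8
+     print(targets.shape)               # torch.Size([8, 10])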