@@ -54,8 +54,7 @@ def _interpolation(kwargs):
54
54
interpolation = kwargs .pop ('resample' , _DEFAULT_INTERPOLATION )
55
55
if isinstance (interpolation , (list , tuple )):
56
56
return random .choice (interpolation )
57
- else :
58
- return interpolation
57
+ return interpolation
59
58
60
59
61
60
def _check_args_tf (kwargs ):
@@ -100,7 +99,7 @@ def rotate(img, degrees, **kwargs):
100
99
_check_args_tf (kwargs )
101
100
if _PIL_VER >= (5 , 2 ):
102
101
return img .rotate (degrees , ** kwargs )
103
- elif _PIL_VER >= (5 , 0 ):
102
+ if _PIL_VER >= (5 , 0 ):
104
103
w , h = img .size
105
104
post_trans = (0 , 0 )
106
105
rotn_center = (w / 2.0 , h / 2.0 )
@@ -124,8 +123,7 @@ def transform(x, y, matrix):
124
123
matrix [2 ] += rotn_center [0 ]
125
124
matrix [5 ] += rotn_center [1 ]
126
125
return img .transform (img .size , Image .AFFINE , matrix , ** kwargs )
127
- else :
128
- return img .rotate (degrees , resample = kwargs ['resample' ])
126
+ return img .rotate (degrees , resample = kwargs ['resample' ])
129
127
130
128
131
129
def auto_contrast (img , ** __ ):
@@ -151,12 +149,13 @@ def solarize_add(img, add, thresh=128, **__):
151
149
lut .append (min (255 , i + add ))
152
150
else :
153
151
lut .append (i )
152
+
154
153
if img .mode in ("L" , "RGB" ):
155
154
if img .mode == "RGB" and len (lut ) == 256 :
156
155
lut = lut + lut + lut
157
156
return img .point (lut )
158
- else :
159
- return img
157
+
158
+ return img
160
159
161
160
162
161
def posterize (img , bits_to_keep , ** __ ):
@@ -226,7 +225,7 @@ def _enhance_increasing_level_to_arg(level, _hparams):
226
225
227
226
def _minmax_level_to_arg (level , _hparams , min_val = 0. , max_val = 1.0 , clamp = True ):
228
227
level = (level / _LEVEL_DENOM )
229
- min_val + (max_val - min_val ) * level
228
+ level = min_val + (max_val - min_val ) * level
230
229
if clamp :
231
230
level = max (min_val , min (max_val , level ))
232
231
return level ,
@@ -552,16 +551,15 @@ def auto_augment_policy(name='v0', hparams=None):
552
551
hparams = hparams or _HPARAMS_DEFAULT
553
552
if name == 'original' :
554
553
return auto_augment_policy_original (hparams )
555
- elif name == 'originalr' :
554
+ if name == 'originalr' :
556
555
return auto_augment_policy_originalr (hparams )
557
- elif name == 'v0' :
556
+ if name == 'v0' :
558
557
return auto_augment_policy_v0 (hparams )
559
- elif name == 'v0r' :
558
+ if name == 'v0r' :
560
559
return auto_augment_policy_v0r (hparams )
561
- elif name == '3a' :
560
+ if name == '3a' :
562
561
return auto_augment_policy_3a (hparams )
563
- else :
564
- assert False , 'Unknown AA policy (%s)' % name
562
+ assert False , f'Unknown AA policy { name } '
565
563
566
564
567
565
class AutoAugment :
@@ -576,7 +574,7 @@ def __call__(self, img):
576
574
return img
577
575
578
576
def __repr__ (self ):
579
- fs = self .__class__ .__name__ + f '(policy='
577
+ fs = self .__class__ .__name__ + '(policy='
580
578
for p in self .policy :
581
579
fs += '\n \t ['
582
580
fs += ', ' .join ([str (op ) for op in p ])
@@ -636,7 +634,7 @@ def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
636
634
'ShearY' ,
637
635
'TranslateXRel' ,
638
636
'TranslateYRel' ,
639
- #'Cutout' # NOTE I've implement this as random erasing separately
637
+ # 'Cutout' # NOTE I've implement this as random erasing separately
640
638
]
641
639
642
640
@@ -656,7 +654,7 @@ def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
656
654
'ShearY' ,
657
655
'TranslateXRel' ,
658
656
'TranslateYRel' ,
659
- #'Cutout' # NOTE I've implement this as random erasing separately
657
+ # 'Cutout' # NOTE I've implement this as random erasing separately
660
658
]
661
659
662
660
@@ -667,7 +665,7 @@ def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
667
665
]
668
666
669
667
670
- _RAND_CHOICE_3A = {
668
+ _RAND_WEIGHTED_3A = {
671
669
'SolarizeIncreasing' : 6 ,
672
670
'Desaturate' : 6 ,
673
671
'GaussianBlur' : 6 ,
@@ -687,7 +685,7 @@ def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
687
685
688
686
# These experimental weights are based loosely on the relative improvements mentioned in paper.
689
687
# They may not result in increased performance, but could likely be tuned to so.
690
- _RAND_CHOICE_WEIGHTS_0 = {
688
+ _RAND_WEIGHTED_0 = {
691
689
'Rotate' : 3 ,
692
690
'ShearX' : 2 ,
693
691
'ShearY' : 2 ,
@@ -715,13 +713,12 @@ def _get_weighted_transforms(transforms: Dict):
715
713
716
714
def rand_augment_choices (name : str , increasing = True ):
717
715
if name == 'weights' :
718
- return _RAND_CHOICE_WEIGHTS_0
719
- elif name == '3aw' :
720
- return _RAND_CHOICE_3A
721
- elif name == '3a' :
716
+ return _RAND_WEIGHTED_0
717
+ if name == '3aw' :
718
+ return _RAND_WEIGHTED_3A
719
+ if name == '3a' :
722
720
return _RAND_3A
723
- else :
724
- return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
721
+ return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
725
722
726
723
727
724
def rand_augment_ops (
0 commit comments