Skip to content

Commit 3a636ee

Browse files
committed
Fix #1713 missed assignement in 3-aug level fn, fix few other minor lint complaints in auto_augment.py
1 parent 82cb47b commit 3a636ee

File tree

1 file changed

+22
-25
lines changed

1 file changed

+22
-25
lines changed

timm/data/auto_augment.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ def _interpolation(kwargs):
5454
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
5555
if isinstance(interpolation, (list, tuple)):
5656
return random.choice(interpolation)
57-
else:
58-
return interpolation
57+
return interpolation
5958

6059

6160
def _check_args_tf(kwargs):
@@ -100,7 +99,7 @@ def rotate(img, degrees, **kwargs):
10099
_check_args_tf(kwargs)
101100
if _PIL_VER >= (5, 2):
102101
return img.rotate(degrees, **kwargs)
103-
elif _PIL_VER >= (5, 0):
102+
if _PIL_VER >= (5, 0):
104103
w, h = img.size
105104
post_trans = (0, 0)
106105
rotn_center = (w / 2.0, h / 2.0)
@@ -124,8 +123,7 @@ def transform(x, y, matrix):
124123
matrix[2] += rotn_center[0]
125124
matrix[5] += rotn_center[1]
126125
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
127-
else:
128-
return img.rotate(degrees, resample=kwargs['resample'])
126+
return img.rotate(degrees, resample=kwargs['resample'])
129127

130128

131129
def auto_contrast(img, **__):
@@ -151,12 +149,13 @@ def solarize_add(img, add, thresh=128, **__):
151149
lut.append(min(255, i + add))
152150
else:
153151
lut.append(i)
152+
154153
if img.mode in ("L", "RGB"):
155154
if img.mode == "RGB" and len(lut) == 256:
156155
lut = lut + lut + lut
157156
return img.point(lut)
158-
else:
159-
return img
157+
158+
return img
160159

161160

162161
def posterize(img, bits_to_keep, **__):
@@ -226,7 +225,7 @@ def _enhance_increasing_level_to_arg(level, _hparams):
226225

227226
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
228227
level = (level / _LEVEL_DENOM)
229-
min_val + (max_val - min_val) * level
228+
level = min_val + (max_val - min_val) * level
230229
if clamp:
231230
level = max(min_val, min(max_val, level))
232231
return level,
@@ -552,16 +551,15 @@ def auto_augment_policy(name='v0', hparams=None):
552551
hparams = hparams or _HPARAMS_DEFAULT
553552
if name == 'original':
554553
return auto_augment_policy_original(hparams)
555-
elif name == 'originalr':
554+
if name == 'originalr':
556555
return auto_augment_policy_originalr(hparams)
557-
elif name == 'v0':
556+
if name == 'v0':
558557
return auto_augment_policy_v0(hparams)
559-
elif name == 'v0r':
558+
if name == 'v0r':
560559
return auto_augment_policy_v0r(hparams)
561-
elif name == '3a':
560+
if name == '3a':
562561
return auto_augment_policy_3a(hparams)
563-
else:
564-
assert False, 'Unknown AA policy (%s)' % name
562+
assert False, f'Unknown AA policy {name}'
565563

566564

567565
class AutoAugment:
@@ -576,7 +574,7 @@ def __call__(self, img):
576574
return img
577575

578576
def __repr__(self):
579-
fs = self.__class__.__name__ + f'(policy='
577+
fs = self.__class__.__name__ + '(policy='
580578
for p in self.policy:
581579
fs += '\n\t['
582580
fs += ', '.join([str(op) for op in p])
@@ -636,7 +634,7 @@ def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
636634
'ShearY',
637635
'TranslateXRel',
638636
'TranslateYRel',
639-
#'Cutout' # NOTE I've implement this as random erasing separately
637+
# 'Cutout' # NOTE I've implement this as random erasing separately
640638
]
641639

642640

@@ -656,7 +654,7 @@ def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
656654
'ShearY',
657655
'TranslateXRel',
658656
'TranslateYRel',
659-
#'Cutout' # NOTE I've implement this as random erasing separately
657+
# 'Cutout' # NOTE I've implement this as random erasing separately
660658
]
661659

662660

@@ -667,7 +665,7 @@ def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
667665
]
668666

669667

670-
_RAND_CHOICE_3A = {
668+
_RAND_WEIGHTED_3A = {
671669
'SolarizeIncreasing': 6,
672670
'Desaturate': 6,
673671
'GaussianBlur': 6,
@@ -687,7 +685,7 @@ def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
687685

688686
# These experimental weights are based loosely on the relative improvements mentioned in paper.
689687
# They may not result in increased performance, but could likely be tuned to so.
690-
_RAND_CHOICE_WEIGHTS_0 = {
688+
_RAND_WEIGHTED_0 = {
691689
'Rotate': 3,
692690
'ShearX': 2,
693691
'ShearY': 2,
@@ -715,13 +713,12 @@ def _get_weighted_transforms(transforms: Dict):
715713

716714
def rand_augment_choices(name: str, increasing=True):
717715
if name == 'weights':
718-
return _RAND_CHOICE_WEIGHTS_0
719-
elif name == '3aw':
720-
return _RAND_CHOICE_3A
721-
elif name == '3a':
716+
return _RAND_WEIGHTED_0
717+
if name == '3aw':
718+
return _RAND_WEIGHTED_3A
719+
if name == '3a':
722720
return _RAND_3A
723-
else:
724-
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
721+
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
725722

726723

727724
def rand_augment_ops(

0 commit comments

Comments
 (0)