Skip to content

Commit 8ab84e4

Browse files
committed
#9 : fix augmentation search space
1 parent 6949784 commit 8ab84e4

File tree

1 file changed

+56
-32
lines changed

1 file changed

+56
-32
lines changed

RandAugment/augmentations.py

Lines changed: 56 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
66
import numpy as np
77
import torch
8+
from PIL import Image
89

910

1011
def ShearX(img, v): # [-0.3, 0.3]
@@ -29,23 +30,23 @@ def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
2930
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
3031

3132

32-
def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
33-
assert -0.45 <= v <= 0.45
33+
def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
34+
assert 0 <= v
3435
if random.random() > 0.5:
3536
v = -v
36-
v = v * img.size[1]
37-
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
37+
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
3838

3939

40-
def TranslateXAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
41-
assert 0 <= v <= 10
40+
def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
41+
assert -0.45 <= v <= 0.45
4242
if random.random() > 0.5:
4343
v = -v
44-
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
44+
v = v * img.size[1]
45+
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
4546

4647

47-
def TranslateYAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
48-
assert 0 <= v <= 10
48+
def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
49+
assert 0 <= v
4950
if random.random() > 0.5:
5051
v = -v
5152
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
@@ -79,14 +80,16 @@ def Solarize(img, v): # [0, 256]
7980
return PIL.ImageOps.solarize(img, v)
8081

8182

82-
def Posterize(img, v): # [4, 8]
83-
assert 4 <= v <= 8
84-
v = int(v)
85-
return PIL.ImageOps.posterize(img, v)
83+
def SolarizeAdd(img, addition=0, threshold=128):
84+
img_np = np.array(img).astype(np.int)
85+
img_np = img_np + addition
86+
img_np = np.clip(img_np, 0, 255)
87+
img_np = img_np.astype(np.uint8)
88+
img = Image.fromarray(img_np)
89+
return PIL.ImageOps.solarize(img, threshold)
8690

8791

88-
def Posterize2(img, v): # [0, 4]
89-
assert 0 <= v <= 4
92+
def Posterize(img, v): # [4, 8]
9093
v = int(v)
9194
return PIL.ImageOps.posterize(img, v)
9295

@@ -156,25 +159,46 @@ def Identity(img, v):
156159

157160
def augment_list(): # 16 oeprations and their ranges
158161
# https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
162+
# l = [
163+
# (Identity, 0., 1.0),
164+
# (ShearX, 0., 0.3), # 0
165+
# (ShearY, 0., 0.3), # 1
166+
# (TranslateX, 0., 0.33), # 2
167+
# (TranslateY, 0., 0.33), # 3
168+
# (Rotate, 0, 30), # 4
169+
# (AutoContrast, 0, 1), # 5
170+
# (Invert, 0, 1), # 6
171+
# (Equalize, 0, 1), # 7
172+
# (Solarize, 0, 110), # 8
173+
# (Posterize, 4, 8), # 9
174+
# # (Contrast, 0.1, 1.9), # 10
175+
# (Color, 0.1, 1.9), # 11
176+
# (Brightness, 0.1, 1.9), # 12
177+
# (Sharpness, 0.1, 1.9), # 13
178+
# # (Cutout, 0, 0.2), # 14
179+
# # (SamplePairing(imgs), 0, 0.4), # 15
180+
# ]
181+
182+
# https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
159183
l = [
160-
(Identity, 0., 1.0),
161-
(ShearX, 0., 0.3), # 0
162-
(ShearY, 0., 0.3), # 1
163-
(TranslateX, 0., 0.33), # 2
164-
(TranslateY, 0., 0.33), # 3
165-
(Rotate, 0, 30), # 4
166-
(AutoContrast, 0, 1), # 5
167-
(Invert, 0, 1), # 6
168-
(Equalize, 0, 1), # 7
169-
(Solarize, 0, 110), # 8
170-
(Posterize, 4, 8), # 9
171-
# (Contrast, 0.1, 1.9), # 10
172-
(Color, 0.1, 1.9), # 11
173-
(Brightness, 0.1, 1.9), # 12
174-
(Sharpness, 0.1, 1.9), # 13
175-
# (Cutout, 0, 0.2), # 14
176-
# (SamplePairing(imgs), 0, 0.4), # 15
184+
(AutoContrast, 0, 1),
185+
(Equalize, 0, 1),
186+
(Invert, 0, 1),
187+
(Rotate, 0, 30),
188+
(Posterize, 0, 4),
189+
(Solarize, 0, 256),
190+
(SolarizeAdd, 0, 110),
191+
(Color, 0.1, 1.9),
192+
(Contrast, 0.1, 1.9),
193+
(Brightness, 0.1, 1.9),
194+
(Sharpness, 0.1, 1.9),
195+
(ShearX, 0., 0.3),
196+
(ShearY, 0., 0.3),
197+
(CutoutAbs, 0, 40),
198+
(TranslateXabs, 0., 100),
199+
(TranslateYabs, 0., 100),
177200
]
201+
178202
return l
179203

180204

0 commit comments

Comments
 (0)