Skip to content

Commit 1d77831

Browse files
normstercopybara-github
authored andcommitted
Implement 3 missing Imagenet-C corruptions. Resize and crop images before applying corruptions. Save corrupted images as JPEG files. Change corruption names to reflect names in original release.
PiperOrigin-RevId: 289176965
1 parent 5e7c62e commit 1d77831

File tree

10 files changed

+230
-36
lines changed

10 files changed

+230
-36
lines changed

tensorflow_datasets/image/corruptions.py

Lines changed: 144 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,33 @@
1515

1616
"""Common corruptions to images.
1717
18-
Define 12+4 common image corruptions: Gaussian noise, shot noise, impulse_noise,
18+
Define 15+4 common image corruptions: Gaussian noise, shot noise, impulse_noise,
1919
defocus blur, frosted glass blur, zoom blur, fog, brightness, contrast, elastic,
20-
pixelate, jpeg compression.
20+
pixelate, jpeg compression, frost, snow, and motion blur.
2121
22-
4 extra corruptions include gaussian blur, saturate, spatter, and speckle
23-
noise.
22+
4 extra corruptions: gaussian blur, saturate, spatter, and speckle noise.
2423
"""
2524

2625
from __future__ import absolute_import
2726
from __future__ import division
2827
from __future__ import print_function
2928

3029
import io
30+
import subprocess
31+
import tempfile
3132
import numpy as np
33+
import tensorflow.compat.v2 as tf
3234
import tensorflow_datasets.public_api as tfds
3335

36+
37+
# To be populated by download_manager
38+
FROST_FILENAMES = []
39+
40+
41+
def _imagemagick_bin():
42+
return 'imagemagick' # pylint: disable=unreachable
43+
44+
3445
# /////////////// Corruption Helpers ///////////////
3546

3647

@@ -239,7 +250,7 @@ def defocus_blur(x, severity=1):
239250
return around_and_astype(x_clip)
240251

241252

242-
def frosted_glass_blur(x, severity=1):
253+
def glass_blur(x, severity=1):
243254
"""Frosted glass blurring to images.
244255
245256
Apply frosted glass blurring to images by shuffling pixels locally.
@@ -367,7 +378,7 @@ def contrast(x, severity=1):
367378
return around_and_astype(x_clip)
368379

369380

370-
def elastic(x, severity=1):
381+
def elastic_transform(x, severity=1):
371382
"""Conduct elastic transform to images.
372383
373384
Elastic transform is performed on small patches of the images.
@@ -469,6 +480,131 @@ def jpeg_compression(x, severity=1):
469480
return np.asarray(x)
470481

471482

483+
def frost(x, severity=1):
484+
"""Apply frost to images.
485+
486+
Args:
487+
x: numpy array, uncorrupted image, assumed to have uint8 pixel in [0,255].
488+
severity: integer, severity of corruption.
489+
490+
Returns:
491+
numpy array, image with uint8 pixels in [0,255]. Applied frost.
492+
"""
493+
c = [(1, 0.4), (0.8, 0.6), (0.7, 0.7), (0.65, 0.7), (0.6, 0.75)][severity - 1]
494+
filename = FROST_FILENAMES[np.random.randint(5)]
495+
with tempfile.NamedTemporaryFile() as im_frost:
496+
tf.io.gfile.copy(filename, im_frost.name, overwrite=True)
497+
frost_img = tfds.core.lazy_imports.cv2.imread(im_frost.name)
498+
# randomly crop and convert to rgb
499+
x_start, y_start = np.random.randint(
500+
0, frost_img.shape[0] - 224), np.random.randint(0,
501+
frost_img.shape[1] - 224)
502+
frost_img = frost_img[x_start:x_start + 224, y_start:y_start + 224][...,
503+
[2, 1, 0]]
504+
505+
x = np.clip(c[0] * np.array(x) + c[1] * frost_img, 0, 255)
506+
507+
return around_and_astype(x)
508+
509+
510+
def snow(x, severity=1):
511+
"""Apply snow to images.
512+
513+
Args:
514+
x: numpy array, uncorrupted image, assumed to have uint8 pixel in [0,255].
515+
severity: integer, severity of corruption.
516+
517+
Returns:
518+
numpy array, image with uint8 pixels in [0,255]. Applied snow.
519+
"""
520+
cv2 = tfds.core.lazy_imports.cv2
521+
PIL_Image = tfds.core.lazy_imports.PIL_Image # pylint: disable=invalid-name
522+
c = [(0.1, 0.3, 3, 0.5, 10, 4, 0.8), (0.2, 0.3, 2, 0.5, 12, 4, 0.7),
523+
(0.55, 0.3, 4, 0.9, 12, 8, 0.7), (0.55, 0.3, 4.5, 0.85, 12, 8, 0.65),
524+
(0.55, 0.3, 2.5, 0.85, 12, 12, 0.55)][severity - 1]
525+
526+
x = np.array(x, dtype=np.float32) / 255.
527+
snow_layer = np.random.normal(
528+
size=x.shape[:2], loc=c[0], scale=c[1]) # [:2] for monochrome
529+
530+
snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
531+
snow_layer[snow_layer < c[3]] = 0
532+
533+
snow_layer = PIL_Image.fromarray(
534+
(np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
535+
536+
with tempfile.NamedTemporaryFile() as im_input:
537+
with tempfile.NamedTemporaryFile() as im_output:
538+
snow_layer.save(im_input.name, format='PNG')
539+
540+
convert_bin = _imagemagick_bin()
541+
radius = c[4]
542+
sigma = c[5]
543+
angle = np.random.uniform(-135, -45)
544+
545+
subprocess.check_output([
546+
convert_bin, '-motion-blur', '{}x{}+{}'.format(radius, sigma, angle),
547+
im_input.name, im_output.name
548+
])
549+
550+
with open(im_output.name, 'rb') as f:
551+
output = f.read()
552+
553+
snow_layer = cv2.imdecode(
554+
np.fromstring(output, np.uint8), cv2.IMREAD_UNCHANGED) / 255.
555+
snow_layer = snow_layer[..., np.newaxis]
556+
557+
x = c[6] * x + (1 - c[6]) * np.maximum(
558+
x,
559+
cv2.cvtColor(x, cv2.COLOR_RGB2GRAY).reshape(224, 224, 1) * 1.5 + 0.5)
560+
x = np.clip(x + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255
561+
562+
return around_and_astype(x)
563+
564+
565+
def motion_blur(x, severity=1):
566+
"""Apply motion blur to images.
567+
568+
Args:
569+
x: numpy array, uncorrupted image, assumed to have uint8 pixel in [0,255].
570+
severity: integer, severity of corruption.
571+
572+
Returns:
573+
numpy array, image with uint8 pixels in [0,255]. Applied motion blur.
574+
"""
575+
c = [(10, 3), (15, 5), (15, 8), (15, 12), (20, 15)][severity - 1]
576+
577+
x = tfds.core.lazy_imports.PIL_Image.fromarray(x.astype(np.uint8))
578+
579+
with tempfile.NamedTemporaryFile() as im_input:
580+
with tempfile.NamedTemporaryFile() as im_output:
581+
x.save(im_input.name, format='PNG')
582+
583+
convert_bin = _imagemagick_bin()
584+
radius = c[0]
585+
sigma = c[1]
586+
angle = np.random.uniform(-45, -45)
587+
588+
subprocess.check_output([
589+
convert_bin, '-motion-blur', '{}x{}+{}'.format(radius, sigma, angle),
590+
im_input.name, im_output.name
591+
])
592+
593+
with open(im_output.name, 'rb') as f:
594+
output = f.read()
595+
596+
x = tfds.core.lazy_imports.cv2.imdecode(
597+
np.fromstring(output, np.uint8),
598+
tfds.core.lazy_imports.cv2.IMREAD_UNCHANGED)
599+
600+
if x.shape != (224, 224):
601+
x = np.clip(x[..., [2, 1, 0]], 0, 255) # BGR to RGB
602+
else: # greyscale to RGB
603+
x = np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 255)
604+
605+
return around_and_astype(x)
606+
607+
472608
# /////////////// Extra Corruptions ///////////////
473609

474610

@@ -484,7 +620,7 @@ def gaussian_blur(x, severity=1):
484620
"""
485621
c = [1, 2, 3, 4, 6][severity - 1]
486622

487-
x = tfds.core.lazy_imports.gaussian(
623+
x = tfds.core.lazy_imports.skimage.filters.gaussian(
488624
np.array(x) / 255., sigma=c, multichannel=True)
489625
x = np.clip(x, 0, 1) * 255
490626

@@ -543,7 +679,7 @@ def spatter(x, severity=1):
543679
# ker = np.array([[-1,-2,-3],[-2,0,0],[-3,0,1]], dtype=np.float32)
544680
# ker -= np.mean(ker)
545681
ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
546-
dist = cv2.filter2D(dist, cv2.CV_8U, ker)
682+
dist = cv2.filter2D(dist, cv2.CVX_8U, ker)
547683
dist = cv2.blur(dist, (3, 3)).astype(np.float32)
548684

549685
m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)

tensorflow_datasets/image/imagenet2012_corrupted.py

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@
3030
_DESCRIPTION = """\
3131
Imagenet2012Corrupted is a dataset generated by adding common corruptions to the
3232
images in the ImageNet dataset. In the original paper, there are 15 + 4
33-
different corruptions, and each has 5 levels of severity. In this dataset, we
34-
implement 12 out of the 15 corruptions, including Gaussian noise, shot noise,
35-
impulse_noise, defocus blur, frosted glass blur, zoom blur, fog, brightness,
36-
contrast, elastic, pixelate, and jpeg compression. We also implement the 4 extra
33+
different corruptions, and each has 5 levels of severity. We also implement the 4 extra
3734
corruptions gaussian blur, saturate, spatter, and speckle noise. The randomness
3835
is fixed so that regeneration is deterministic.
3936
"""
@@ -56,25 +53,37 @@
5653
# tar file).
5754
_VALIDATION_LABELS_FNAME = 'image/imagenet2012_validation_labels.txt'
5855

59-
# TODO(normanmu): implement frost, snow, and motion blur once wand library is
60-
# upgraded (cl/262186801)
56+
_FROST_FILEBASE = 'https://raw.githubusercontent.com/hendrycks/robustness/master/ImageNet-C/imagenet_c/imagenet_c/frost/'
57+
_FROST_FILENAMES = [
58+
_FROST_FILEBASE + f for f in [
59+
'frost1.png', 'frost2.png', 'frost3.png', 'frost4.jpg', 'frost5.jpg',
60+
'frost6.jpg'
61+
]
62+
]
63+
6164
BENCHMARK_CORRUPTIONS = [
6265
'gaussian_noise',
6366
'shot_noise',
6467
'impulse_noise',
6568
'defocus_blur',
66-
'frosted_glass_blur',
69+
'glass_blur',
70+
'motion_blur',
6771
'zoom_blur',
72+
'snow',
73+
'frost',
6874
'fog',
6975
'brightness',
7076
'contrast',
71-
'elastic',
77+
'elastic_transform',
7278
'pixelate',
7379
'jpeg_compression',
7480
]
7581

7682
EXTRA_CORRUPTIONS = ['gaussian_blur', 'saturate', 'spatter', 'speckle_noise']
7783

84+
_IMAGE_SIZE = 224
85+
_CROP_PADDING = 32
86+
7887

7988
class Imagenet2012CorruptedConfig(tfds.core.BuilderConfig):
8089
"""BuilderConfig for Imagenet2012Corrupted."""
@@ -94,23 +103,21 @@ def __init__(self, corruption_type=None, severity=1, **kwargs):
94103
self.severity = severity
95104

96105

97-
_VERSION = tfds.core.Version(
98-
'0.0.1', experiments={tfds.core.Experiment.S3: False})
99-
_SUPPORTED_VERSIONS = [
100-
tfds.core.Version('3.0.1', (
101-
'New split API (https://tensorflow.org/datasets/splits); fix colorization (all RGB) and '
102-
'format (all jpeg); use TAR_STREAM.')),
103-
]
106+
_VERSION = tfds.core.Version('3.1.0')
107+
108+
# Version history:
109+
# 3.1.0: Implement missing corruptions. Fix crop/resize ordering, file encoding
110+
# 0.0.1: Initial dataset.
104111

105112

106113
def _make_builder_configs():
107114
"""Construct a list of BuilderConfigs.
108115
109-
Construct a list of 80 Imagenet2012CorruptedConfig objects, corresponding to
110-
the 12 + 4 corruption types, with each type having 5 severities.
116+
Construct a list of 95 Imagenet2012CorruptedConfig objects, corresponding to
117+
the 15 + 4 corruption types, with each type having 5 severities.
111118
112119
Returns:
113-
A list of Imagenet2012CorruptedConfig objects.
120+
A list of 95 Imagenet2012CorruptedConfig objects.
114121
"""
115122
config_list = []
116123
for each_corruption in BENCHMARK_CORRUPTIONS + EXTRA_CORRUPTIONS:
@@ -122,14 +129,37 @@ def _make_builder_configs():
122129
Imagenet2012CorruptedConfig(
123130
name=name_str,
124131
version=_VERSION,
125-
supported_versions=_SUPPORTED_VERSIONS,
126132
description=description_str,
127133
corruption_type=each_corruption,
128134
severity=each_severity,
129135
))
130136
return config_list
131137

132138

139+
def _decode_and_center_crop(image_bytes):
140+
"""Crops to center of image with padding then scales image size."""
141+
shape = tf.image.extract_jpeg_shape(image_bytes)
142+
image_height = shape[0]
143+
image_width = shape[1]
144+
145+
padded_center_crop_size = tf.cast(
146+
((_IMAGE_SIZE / (_IMAGE_SIZE + _CROP_PADDING)) *
147+
tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)
148+
149+
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
150+
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
151+
crop_window = tf.stack([
152+
offset_height, offset_width, padded_center_crop_size,
153+
padded_center_crop_size
154+
])
155+
image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
156+
image = tf.image.resize([image], [_IMAGE_SIZE, _IMAGE_SIZE],
157+
method=tf.image.ResizeMethod.BICUBIC)[0]
158+
image = tf.cast(image, tf.int32)
159+
160+
return image
161+
162+
133163
class Imagenet2012Corrupted(Imagenet2012):
134164
"""Corrupted ImageNet2012 dataset."""
135165
BUILDER_CONFIGS = _make_builder_configs()
@@ -145,15 +175,27 @@ def _info(self):
145175
builder=self,
146176
description=_DESCRIPTION,
147177
features=tfds.features.FeaturesDict({
148-
'image': tfds.features.Image(),
149-
'label': tfds.features.ClassLabel(names_file=names_file),
150-
'file_name': tfds.features.Text(), # Eg: 'n15075141_54.JPEG'
178+
'image':
179+
tfds.features.Image(
180+
shape=(_IMAGE_SIZE, _IMAGE_SIZE, 3),
181+
encoding_format='jpeg'),
182+
'label':
183+
tfds.features.ClassLabel(names_file=names_file),
184+
'file_name':
185+
tfds.features.Text(), # Eg: 'n15075141_54.JPEG'
151186
}),
152187
supervised_keys=('image', 'label'),
153188
homepage='https://openreview.net/forum?id=HJz6tiCqYm',
154189
citation=_CITATION,
155190
)
156191

192+
def _split_generators(self, dl_manager):
193+
"""Filter out training split as ImageNet-C is a testing benchmark."""
194+
splits = super(Imagenet2012Corrupted, self)._split_generators(dl_manager)
195+
196+
corruptions.FROST_FILENAMES = dl_manager.download(_FROST_FILENAMES)
197+
return [s for s in splits if s.name != tfds.Split.TRAIN]
198+
157199
def _generate_examples(self, archive, validation_labels=None):
158200
"""Generate corrupted imagenet validation data.
159201
@@ -177,7 +219,7 @@ def _generate_examples(self, archive, validation_labels=None):
177219
gen_fn = super(Imagenet2012Corrupted, self)._generate_examples
178220
for key, example in gen_fn(archive, validation_labels):
179221
with tf.Graph().as_default():
180-
tf_img = tf.image.decode_jpeg(example['image'].read(), channels=3)
222+
tf_img = _decode_and_center_crop(example['image'].read())
181223
image_np = tfds.as_numpy(tf_img)
182224
example['image'] = self._get_corrupted_example(image_np)
183225

@@ -196,18 +238,22 @@ def _get_corrupted_example(self, x):
196238
"""
197239
corruption_type = self.builder_config.corruption_type
198240
severity = self.builder_config.severity
241+
x = np.clip(x, 0, 255)
199242

200243
return {
201244
'gaussian_noise': corruptions.gaussian_noise,
202245
'shot_noise': corruptions.shot_noise,
203246
'impulse_noise': corruptions.impulse_noise,
204247
'defocus_blur': corruptions.defocus_blur,
205-
'frosted_glass_blur': corruptions.frosted_glass_blur,
248+
'glass_blur': corruptions.glass_blur,
249+
'motion_blur': corruptions.motion_blur,
206250
'zoom_blur': corruptions.zoom_blur,
251+
'snow': corruptions.snow,
252+
'frost': corruptions.frost,
207253
'fog': corruptions.fog,
208254
'brightness': corruptions.brightness,
209255
'contrast': corruptions.contrast,
210-
'elastic': corruptions.elastic,
256+
'elastic_transform': corruptions.elastic_transform,
211257
'pixelate': corruptions.pixelate,
212258
'jpeg_compression': corruptions.jpeg_compression,
213259
'gaussian_blur': corruptions.gaussian_blur,

0 commit comments

Comments
 (0)