Skip to content

Commit 234af41

Browse files
normstercopybara-github
authored andcommitted
Corrupted MNIST dataset.
PiperOrigin-RevId: 257034402
1 parent b4bfa92 commit 234af41

File tree

9 files changed

+239
-6
lines changed

9 files changed

+239
-6
lines changed

tensorflow_datasets/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from tensorflow_datasets.image.mnist import FashionMNIST
4848
from tensorflow_datasets.image.mnist import KMNIST
4949
from tensorflow_datasets.image.mnist import MNIST
50+
from tensorflow_datasets.image.mnist_corrupted import MNISTCorrupted
5051
from tensorflow_datasets.image.omniglot import Omniglot
5152
from tensorflow_datasets.image.open_images import OpenImagesV4
5253
from tensorflow_datasets.image.oxford_flowers102 import OxfordFlowers102

tensorflow_datasets/image/mnist.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
_MNIST_TEST_DATA_FILENAME = "t10k-images-idx3-ubyte.gz"
3636
_MNIST_TEST_LABELS_FILENAME = "t10k-labels-idx1-ubyte.gz"
3737
_MNIST_IMAGE_SIZE = 28
38-
_MNIST_IMAGE_SHAPE = (_MNIST_IMAGE_SIZE, _MNIST_IMAGE_SIZE, 1)
38+
MNIST_IMAGE_SHAPE = (_MNIST_IMAGE_SIZE, _MNIST_IMAGE_SIZE, 1)
39+
MNIST_NUM_CLASSES = 10
3940
_TRAIN_EXAMPLES = 60000
4041
_TEST_EXAMPLES = 10000
4142

@@ -107,8 +108,8 @@ def _info(self):
107108
builder=self,
108109
description=("The MNIST database of handwritten digits."),
109110
features=tfds.features.FeaturesDict({
110-
"image": tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
111-
"label": tfds.features.ClassLabel(num_classes=10),
111+
"image": tfds.features.Image(shape=MNIST_IMAGE_SHAPE),
112+
"label": tfds.features.ClassLabel(num_classes=MNIST_NUM_CLASSES),
112113
}),
113114
supervised_keys=("image", "label"),
114115
urls=[self.URL],
@@ -188,7 +189,7 @@ def _info(self):
188189
"classes."),
189190
features=tfds.features.FeaturesDict({
190191
"image":
191-
tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
192+
tfds.features.Image(shape=MNIST_IMAGE_SHAPE),
192193
"label":
193194
tfds.features.ClassLabel(names=[
194195
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
@@ -217,7 +218,7 @@ def _info(self):
217218
"when creating Kuzushiji-MNIST."),
218219
features=tfds.features.FeaturesDict({
219220
"image":
220-
tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
221+
tfds.features.Image(shape=MNIST_IMAGE_SHAPE),
221222
"label":
222223
tfds.features.ClassLabel(names=[
223224
"o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"
@@ -319,7 +320,7 @@ def _info(self):
319320
"matches the MNIST dataset."),
320321
features=tfds.features.FeaturesDict({
321322
"image":
322-
tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
323+
tfds.features.Image(shape=MNIST_IMAGE_SHAPE),
323324
"label":
324325
tfds.features.ClassLabel(
325326
num_classes=self.builder_config.class_number),
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Corrupted MNIST Dataset.
17+
18+
MNISTCorrupted is a dataset generated by adding 15 corruptions to the test
19+
images in the MNIST dataset. This dataset wraps the static, corrupted MNIST
20+
test images uploaded by the original authors.
21+
"""
22+
23+
from __future__ import absolute_import
24+
from __future__ import division
25+
from __future__ import print_function
26+
27+
import os
28+
29+
import numpy as np
30+
import tensorflow as tf
31+
from tensorflow_datasets.core import api_utils
32+
from tensorflow_datasets.image import mnist
33+
import tensorflow_datasets.public_api as tfds
34+
35+
_DESCRIPTION = """\
36+
MNISTCorrupted is a dataset generated by adding 15 corruptions to the test
37+
images in the MNIST dataset. This dataset wraps the static, corrupted MNIST
38+
test images uploaded by the original authors
39+
"""
40+
41+
_CITATION = """
42+
@article{mu2019mnist,
43+
title={MNIST-C: A Robustness Benchmark for Computer Vision},
44+
author={Mu, Norman and Gilmer, Justin},
45+
journal={arXiv preprint arXiv:1906.02337},
46+
year={2019}
47+
}
48+
"""
49+
50+
_DOWNLOAD_URL = 'https://zenodo.org/record/3239543/files/mnist_c.zip'
51+
_CORRUPTIONS = [
52+
'identity',
53+
'shot_noise',
54+
'impulse_noise',
55+
'glass_blur',
56+
'motion_blur',
57+
'shear',
58+
'scale',
59+
'rotate',
60+
'brightness',
61+
'translate',
62+
'stripe',
63+
'fog',
64+
'spatter',
65+
'dotted_line',
66+
'zigzag',
67+
'canny_edges',
68+
]
69+
_DIRNAME = 'mnist_c'
70+
_TRAIN_IMAGES_FILENAME = 'train_images.npy'
71+
_TEST_IMAGES_FILENAME = 'test_images.npy'
72+
_TRAIN_LABELS_FILENAME = 'train_labels.npy'
73+
_TEST_LABELS_FILENAME = 'test_labels.npy'
74+
75+
76+
class MNISTCorruptedConfig(tfds.core.BuilderConfig):
77+
"""BuilderConfig for MNISTcorrupted."""
78+
79+
@api_utils.disallow_positional_args
80+
def __init__(self, corruption_type, **kwargs):
81+
"""Constructor.
82+
83+
Args:
84+
corruption_type: string, name of corruption from _CORRUPTIONS.
85+
**kwargs: keyword arguments forwarded to super.
86+
"""
87+
super(MNISTCorruptedConfig, self).__init__(**kwargs)
88+
self.corruption = corruption_type
89+
90+
91+
def _make_builder_configs():
92+
"""Construct a list of BuilderConfigs.
93+
94+
Construct a list of 15 MNISTCorruptedConfig objects, corresponding to
95+
the 15 corruption types.
96+
97+
Returns:
98+
A list of 15 MNISTCorruptedConfig objects.
99+
"""
100+
config_list = []
101+
for corruption in _CORRUPTIONS:
102+
config_list.append(
103+
MNISTCorruptedConfig(
104+
name=corruption,
105+
version='0.0.1',
106+
description='Corruption method: ' + corruption,
107+
corruption_type=corruption,
108+
))
109+
return config_list
110+
111+
112+
class MNISTCorrupted(tfds.core.GeneratorBasedBuilder):
113+
"""Corrupted MNIST dataset."""
114+
BUILDER_CONFIGS = _make_builder_configs()
115+
116+
def _info(self):
117+
"""Returns basic information of dataset.
118+
119+
Returns:
120+
tfds.core.DatasetInfo.
121+
"""
122+
return tfds.core.DatasetInfo(
123+
builder=self,
124+
description=_DESCRIPTION,
125+
features=tfds.features.FeaturesDict({
126+
'image':
127+
tfds.features.Image(shape=mnist.MNIST_IMAGE_SHAPE),
128+
'label':
129+
tfds.features.ClassLabel(num_classes=mnist.MNIST_NUM_CLASSES),
130+
}),
131+
supervised_keys=('image', 'label'),
132+
urls=['https://github.com/google-research/mnist-c'],
133+
citation=_CITATION)
134+
135+
def _split_generators(self, dl_manager):
136+
"""Return the train, test split of MNIST-C.
137+
138+
Args:
139+
dl_manager: download manager object.
140+
141+
Returns:
142+
train split, test split.
143+
"""
144+
path = dl_manager.download_and_extract(_DOWNLOAD_URL)
145+
return [
146+
tfds.core.SplitGenerator(
147+
name=tfds.Split.TRAIN,
148+
num_shards=1,
149+
gen_kwargs={
150+
'data_dir': os.path.join(path, _DIRNAME),
151+
'is_train': True
152+
}),
153+
tfds.core.SplitGenerator(
154+
name=tfds.Split.TEST,
155+
num_shards=1,
156+
gen_kwargs={
157+
'data_dir': os.path.join(path, _DIRNAME),
158+
'is_train': False
159+
}),
160+
]
161+
162+
def _generate_examples(self, data_dir, is_train):
163+
"""Generate corrupted MNIST data.
164+
165+
Apply corruptions to the raw images according to self.corruption_type.
166+
167+
Args:
168+
data_dir: root directory of downloaded dataset
169+
is_train: whether to return train images or test images
170+
171+
Yields:
172+
dictionary with image file and label.
173+
"""
174+
corruption = self.builder_config.corruption
175+
176+
if is_train:
177+
images_file = os.path.join(data_dir, corruption, _TRAIN_IMAGES_FILENAME)
178+
labels_file = os.path.join(data_dir, corruption, _TRAIN_LABELS_FILENAME)
179+
else:
180+
images_file = os.path.join(data_dir, corruption, _TEST_IMAGES_FILENAME)
181+
labels_file = os.path.join(data_dir, corruption, _TEST_LABELS_FILENAME)
182+
183+
with tf.io.gfile.GFile(labels_file, mode='rb') as f:
184+
labels = np.load(f)
185+
186+
with tf.io.gfile.GFile(images_file, mode='rb') as f:
187+
images = np.load(f)
188+
189+
for image, label in zip(images, labels):
190+
yield {
191+
'image': image,
192+
'label': label,
193+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for corrupted MNIST."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
from tensorflow_datasets import testing
23+
from tensorflow_datasets.image import mnist_corrupted
24+
25+
26+
class MNISTCorruptedTest(testing.DatasetBuilderTestCase):
27+
28+
BUILDER_CONFIG_NAMES_TO_TEST = ["dotted_line"]
29+
30+
DATASET_CLASS = mnist_corrupted.MNISTCorrupted
31+
SPLITS = {
32+
"train": 2,
33+
"test": 2,
34+
}
35+
36+
if __name__ == "__main__":
37+
testing.test_main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
https://zenodo.org/record/3239543/files/mnist_c.zip 246661575 af9ee8c6a815870c7fdde5af84c7bf8db0bcfa1f41056db83871037fba70e493

0 commit comments

Comments
 (0)