Skip to content

Commit a266ec3

Browse files
Merge pull request #198 from us:add-emnist-dataset
PiperOrigin-RevId: 238280883
2 parents c8c00f0 + 7699041 commit a266ec3

File tree

8 files changed

+208
-34
lines changed

8 files changed

+208
-34
lines changed

tensorflow_datasets/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from tensorflow_datasets.image.image_folder import ImageLabelFolder
3131
from tensorflow_datasets.image.imagenet import Imagenet2012
3232
from tensorflow_datasets.image.lsun import Lsun
33+
from tensorflow_datasets.image.mnist import EMNIST
3334
from tensorflow_datasets.image.mnist import FashionMNIST
3435
from tensorflow_datasets.image.mnist import KMNIST
3536
from tensorflow_datasets.image.mnist import MNIST

tensorflow_datasets/image/mnist.py

Lines changed: 199 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,18 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""MNIST and Fashion MNIST."""
16+
"""MNIST, Fashion MNIST, KMNIST and EMNIST."""
1717

1818
from __future__ import absolute_import
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import os
2223
import numpy as np
2324
import six.moves.urllib as urllib
2425
import tensorflow as tf
2526

27+
from tensorflow_datasets.core import api_utils
2628
import tensorflow_datasets.public_api as tfds
2729

2830
# MNIST constants
@@ -36,7 +38,6 @@
3638
_TRAIN_EXAMPLES = 60000
3739
_TEST_EXAMPLES = 10000
3840

39-
4041
_MNIST_CITATION = """\
4142
@article{lecun2010mnist,
4243
title={MNIST handwritten digit database},
@@ -47,7 +48,6 @@
4748
}
4849
"""
4950

50-
5151
_FASHION_MNIST_CITATION = """\
5252
@article{DBLP:journals/corr/abs-1708-07747,
5353
author = {Han Xiao and
@@ -67,25 +67,25 @@
6767
}
6868
"""
6969

70-
7170
_K_MNIST_CITATION = """\
72-
@article{DBLP:journals/corr/abs-1812-01718,
73-
author = {Tarin Clanuwat and
74-
Mikel Bober{-}Irizar and
75-
Asanobu Kitamoto and
76-
Alex Lamb and
77-
Kazuaki Yamamoto and
78-
David Ha},
79-
title = {Deep Learning for Classical Japanese Literature},
80-
journal = {CoRR},
81-
volume = {abs/1812.01718},
82-
year = {2018},
83-
url = {http://arxiv.org/abs/1812.01718},
84-
archivePrefix = {arXiv},
85-
eprint = {1812.01718},
86-
timestamp = {Tue, 01 Jan 2019 15:01:25 +0100},
87-
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1812-01718},
88-
bibsource = {dblp computer science bibliography, https://dblp.org}
71+
@online{clanuwat2018deep,
72+
author = {Tarin Clanuwat and Mikel Bober-Irizar and Asanobu Kitamoto and Alex Lamb and Kazuaki Yamamoto and David Ha},
73+
title = {Deep Learning for Classical Japanese Literature},
74+
date = {2018-12-03},
75+
year = {2018},
76+
eprintclass = {cs.CV},
77+
eprinttype = {arXiv},
78+
eprint = {cs.CV/1812.01718},
79+
}
80+
"""
81+
82+
_EMNIST_CITATION = """\
83+
@article{cohen_afshar_tapson_schaik_2017,
84+
title={EMNIST: Extending MNIST to handwritten letters},
85+
DOI={10.1109/ijcnn.2017.7966217},
86+
journal={2017 International Joint Conference on Neural Networks (IJCNN)},
87+
author={Cohen, Gregory and Afshar, Saeed and Tapson, Jonathan and Schaik, Andre Van},
88+
year={2017}
8989
}
9090
"""
9191

@@ -118,9 +118,8 @@ def _split_generators(self, dl_manager):
118118
"test_data": _MNIST_TEST_DATA_FILENAME,
119119
"test_labels": _MNIST_TEST_LABELS_FILENAME,
120120
}
121-
mnist_files = dl_manager.download_and_extract({
122-
k: urllib.parse.urljoin(self.URL, v) for k, v in filenames.items()
123-
})
121+
mnist_files = dl_manager.download_and_extract(
122+
{k: urllib.parse.urljoin(self.URL, v) for k, v in filenames.items()})
124123

125124
# MNIST provides TRAIN and TEST splits, not a VALIDATION split, so we only
126125
# write the TRAIN and TEST splits to disk.
@@ -181,11 +180,13 @@ def _info(self):
181180
"grayscale image, associated with a label from 10 "
182181
"classes."),
183182
features=tfds.features.FeaturesDict({
184-
"image": tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
185-
"label": tfds.features.ClassLabel(names=[
186-
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
187-
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
188-
]),
183+
"image":
184+
tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
185+
"label":
186+
tfds.features.ClassLabel(names=[
187+
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
188+
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
189+
]),
189190
}),
190191
supervised_keys=("image", "label"),
191192
urls=["https://github.com/zalandoresearch/fashion-mnist"],
@@ -208,17 +209,182 @@ def _info(self):
208209
"character to represent each of the 10 rows of Hiragana "
209210
"when creating Kuzushiji-MNIST."),
210211
features=tfds.features.FeaturesDict({
211-
"image": tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
212-
"label": tfds.features.ClassLabel(names=[
213-
"o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"
214-
]),
212+
"image":
213+
tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
214+
"label":
215+
tfds.features.ClassLabel(names=[
216+
"o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"
217+
]),
215218
}),
216219
supervised_keys=("image", "label"),
217220
urls=["http://codh.rois.ac.jp/kmnist/index.html.en"],
218221
citation=_K_MNIST_CITATION,
219222
)
220223

221224

225+
class EMNISTConfig(tfds.core.BuilderConfig):
  """BuilderConfig describing one variant of the EMNIST dataset."""

  @api_utils.disallow_positional_args
  def __init__(self, class_number, train_examples, test_examples, **kwargs):
    """Creates the config for a single EMNIST variant.

    All arguments must be passed by keyword (positional arguments are rejected
    by the `disallow_positional_args` decorator).

    Args:
      class_number: value stored as `self.class_number`; the label class count
        of this variant.
      train_examples: number of examples in the train split.
      test_examples: number of examples in the test split.
      **kwargs: keyword arguments forwarded to `tfds.core.BuilderConfig`.
    """
    super(EMNISTConfig, self).__init__(**kwargs)
    # Split sizes and class count are kept on the config so the builder can
    # read them back through `builder_config`.
    self.train_examples = train_examples
    self.test_examples = test_examples
    self.class_number = class_number
243+
244+
245+
class EMNIST(MNIST):
  """EMNIST dataset builder.

  One `EMNISTConfig` is declared per EMNIST variant, plus a tiny "test"
  config. The data files are not downloaded by this builder: they are read
  from `dl_manager.manual_dir`, where they must have been extracted by hand.
  """

  VERSION = tfds.core.Version("1.0.1")

  BUILDER_CONFIGS = [
      EMNISTConfig(
          name="byclass",
          class_number=62,
          train_examples=697932,
          test_examples=116323,
          description=(
              "EMNIST ByClass: 814,255 characters. 62 unbalanced classes."),
          version="1.0.1",
      ),
      EMNISTConfig(
          name="bymerge",
          class_number=47,
          train_examples=697932,
          test_examples=116323,
          description=(
              "EMNIST ByMerge: 814,255 characters. 47 unbalanced classes."),
          version="1.0.1",
      ),
      EMNISTConfig(
          name="balanced",
          class_number=47,
          train_examples=112800,
          test_examples=18800,
          description=(
              "EMNIST Balanced: 131,600 characters. 47 balanced classes."),
          version="1.0.1",
      ),
      # NOTE(review): class_number (37) does not match the 26 classes stated
      # in the description; left unchanged -- confirm against the label range
      # actually stored in the letters label files.
      EMNISTConfig(
          name="letters",
          class_number=37,
          train_examples=88800,
          test_examples=14800,
          description=(
              "EMNIST Letters: 103,600 characters. 26 balanced classes."),
          version="1.0.1",
      ),
      EMNISTConfig(
          name="digits",
          class_number=10,
          train_examples=240000,
          test_examples=40000,
          description=(
              "EMNIST Digits: 280,000 characters. 10 balanced classes."),
          version="1.0.1",
      ),
      EMNISTConfig(
          name="mnist",
          class_number=10,
          train_examples=60000,
          test_examples=10000,
          description=(
              "EMNIST MNIST: 70,000 characters. 10 balanced classes."),
          version="1.0.1",
      ),
      # Tiny config (10 train / 2 test examples) selected by the unit tests.
      EMNISTConfig(
          name="test",
          class_number=62,
          train_examples=10,
          test_examples=2,
          description="EMNIST test data config.",
          version="1.0.1",
      ),
  ]

  def _info(self):
    """Returns the `tfds.core.DatasetInfo` for the selected config.

    The label feature is sized from `self.builder_config.class_number`.
    """
    return tfds.core.DatasetInfo(
        builder=self,
        # Each fragment ends with a space: adjacent string literals are
        # concatenated verbatim, so without it the words run together.
        description=(
            "The EMNIST dataset is a set of handwritten character digits "
            "derived from the NIST Special Database 19 and converted to "
            "a 28x28 pixel image format and dataset structure that directly "
            "matches the MNIST dataset."),
        features=tfds.features.FeaturesDict({
            "image":
                tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
            "label":
                tfds.features.ClassLabel(
                    num_classes=self.builder_config.class_number),
        }),
        supervised_keys=("image", "label"),
        urls=["https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"],
        citation=_EMNIST_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Returns the TRAIN and TEST split generators for the selected config.

    Args:
      dl_manager: download manager; only its `manual_dir` attribute is used.

    Returns:
      A list of two `tfds.core.SplitGenerator`s (TRAIN with 10 shards, TEST
      with 1 shard). Each carries `num_examples`, `data_path` and `label_path`
      in its `gen_kwargs`.

    Raises:
      Exception: if the train images file of the selected config is not found
        in `dl_manager.manual_dir`.
    """
    config_name = self.builder_config.name
    filenames = {
        "train_data":
            "emnist-{}-train-images-idx3-ubyte".format(config_name),
        "train_labels":
            "emnist-{}-train-labels-idx1-ubyte".format(config_name),
        "test_data":
            "emnist-{}-test-images-idx3-ubyte".format(config_name),
        "test_labels":
            "emnist-{}-test-labels-idx1-ubyte".format(config_name),
    }

    dir_name = dl_manager.manual_dir

    if not tf.io.gfile.exists(os.path.join(dir_name, filenames["train_data"])):
      # The current tfds.core.download_manager is unable to
      # extract multiple and nested files.
      # We'll add soon! (Issue 234)
      #
      # The message is assembled line by line so that its spacing is exactly
      # what the user sees, and label files are listed with the `idx1` suffix
      # that this builder actually looks for.
      msg = (
          "You must download and extract the dataset files manually and "
          "place them in : {}\n"
          "File tree must be like this :\n"
          "\n"
          ".\n"
          "|-- emnist\n"
          "|   |-- emnist-byclass-train-images-idx3-ubyte\n"
          "|   |-- emnist-byclass-train-labels-idx1-ubyte\n"
          "|   |-- emnist-byclass-test-images-idx3-ubyte\n"
          "|   |-- emnist-byclass-test-labels-idx1-ubyte\n"
          "|   |-- emnist-bymerge-train-images-idx3-ubyte\n"
          "|   |-- emnist-bymerge-train-labels-idx1-ubyte\n"
          "|   |-- emnist-bymerge-test-images-idx3-ubyte\n"
          "|   |-- emnist-bymerge-test-labels-idx1-ubyte\n"
          "|   |-- .......\n"
      ).format(dir_name)
      raise Exception(msg)

    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            num_shards=10,
            gen_kwargs=dict(
                num_examples=self.builder_config.train_examples,
                data_path=os.path.join(dir_name, filenames["train_data"]),
                label_path=os.path.join(dir_name, filenames["train_labels"]),
            )),
        tfds.core.SplitGenerator(
            name=tfds.Split.TEST,
            num_shards=1,
            gen_kwargs=dict(
                num_examples=self.builder_config.test_examples,
                data_path=os.path.join(dir_name, filenames["test_data"]),
                label_path=os.path.join(dir_name, filenames["test_labels"]),
            )),
    ]
386+
387+
222388
def _extract_mnist_images(image_filepath, num_images):
223389
with tf.io.gfile.GFile(image_filepath, "rb") as f:
224390
f.read(16) # header

tensorflow_datasets/image/mnist_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,10 @@ class KMNISTTest(MNISTTest):
5151
DATASET_CLASS = mnist.KMNIST
5252

5353

54+
class EMNISTTest(MNISTTest):
  """Runs the tests inherited from `MNISTTest` against the EMNIST builder."""

  # Builder class under test.
  DATASET_CLASS = mnist.EMNIST
  # Only the builder config named "test" is exercised.
  BUILDER_CONFIG_NAMES_TO_TEST = ["test"]
57+
58+
5459
if __name__ == "__main__":
5560
testing.test_main()

tensorflow_datasets/testing/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def write_label_file(filename, num_labels):
7171

7272

7373
def main(_):
74-
for mnist in ["mnist", "fashion_mnist", "kmnist"]:
74+
for mnist in ["mnist", "fashion_mnist", "kmnist", "emnist"]:
7575
output_dir = mnist_dir(mnist)
7676
test_utils.remake_dir(output_dir)
7777
write_image_file(os.path.join(output_dir, _TRAIN_DATA_FILENAME), 10)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
11111111
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
11111111

0 commit comments

Comments
 (0)