Skip to content

Commit a9a6b82

Browse files
committed
Added EMNIST Dataset
1 parent 93f12d3 commit a9a6b82

File tree

2 files changed

+152
-8
lines changed

2 files changed

+152
-8
lines changed

tensorflow_datasets/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from tensorflow_datasets.image.mnist import FashionMNIST
3434
from tensorflow_datasets.image.mnist import MNIST
3535
from tensorflow_datasets.image.mnist import KMNIST
36+
from tensorflow_datasets.image.mnist import EMNIST
3637
from tensorflow_datasets.image.omniglot import Omniglot
3738
from tensorflow_datasets.image.open_images import OpenImagesV4
3839
from tensorflow_datasets.image.quickdraw import QuickdrawBitmap

tensorflow_datasets/image/mnist.py

Lines changed: 151 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""MNIST and Fashion MNIST."""
16+
"""MNIST, Fashion MNIST, KMNIST and EMNIST."""
1717

1818
from __future__ import absolute_import
1919
from __future__ import division
@@ -23,6 +23,7 @@
2323
import six.moves.urllib as urllib
2424
import tensorflow as tf
2525

26+
from tensorflow_datasets.core import api_utils
2627
import tensorflow_datasets.public_api as tfds
2728

2829
# MNIST constants
@@ -68,7 +69,7 @@
6869
"""
6970

7071

71-
_K_MNIST_CITATION ="""
72+
_K_MNIST_CITATION = """\
7273
@online{clanuwat2018deep,
7374
author = {Tarin Clanuwat and Mikel Bober-Irizar and Asanobu Kitamoto and Alex Lamb and Kazuaki Yamamoto and David Ha},
7475
title = {Deep Learning for Classical Japanese Literature},
@@ -77,7 +78,17 @@
7778
eprintclass = {cs.CV},
7879
eprinttype = {arXiv},
7980
eprint = {cs.CV/1812.01718},
80-
}
81+
}
82+
"""
83+
84+
_EMNIST_CITATION = """\
85+
@article{cohen_afshar_tapson_schaik_2017,
86+
title={EMNIST: Extending MNIST to handwritten letters},
87+
DOI={10.1109/ijcnn.2017.7966217},
88+
journal={2017 International Joint Conference on Neural Networks (IJCNN)},
89+
author={Cohen, Gregory and Afshar, Saeed and Tapson, Jonathan and Schaik, Andre Van},
90+
year={2017}
91+
}
8192
"""
8293

8394
class MNIST(tfds.core.GeneratorBasedBuilder):
@@ -207,6 +218,143 @@ def _info(self):
207218
citation=_K_MNIST_CITATION,
208219
)
209220

221+
class EMNISTConfig(tfds.core.BuilderConfig):
222+
"""BuilderConfig for EMNIST CONFIG."""
223+
224+
@api_utils.disallow_positional_args
225+
def __init__(self, class_number, train_examples, test_examples, **kwargs):
226+
"""BuilderConfig for EMNIST class number.
227+
228+
Args:
229+
class_number: There are six different splits provided in this dataset. And have
230+
different class numbers.
231+
232+
train_examples, test_examples: So in these have different test and train character
233+
numbers.
234+
235+
**kwargs: keyword arguments forwarded to super.
236+
"""
237+
super(EMNISTConfig, self).__init__(**kwargs)
238+
self.class_number = class_number
239+
self.train_examples = train_examples
240+
self.test_examples = test_examples
241+
242+
243+
class EMNIST(MNIST):
244+
245+
VERSION = tfds.core.Version('1.0.0')
246+
247+
BUILDER_CONFIGS = [
248+
EMNISTConfig(
249+
name="byclass",
250+
class_number=62,
251+
train_examples=697932,
252+
test_examples=116323,
253+
description="EMNIST ByClass: 814,255 characters. 62 unbalanced classes.",
254+
version="0.1.1",
255+
),
256+
EMNISTConfig(
257+
name="bymerge",
258+
class_number=47,
259+
train_examples=697932,
260+
test_examples=116323,
261+
description="EMNIST ByMerge: 814,255 characters. 47 unbalanced classes.",
262+
version="0.1.1",
263+
),
264+
EMNISTConfig(
265+
name="balanced",
266+
class_number=47,
267+
train_examples=112800,
268+
test_examples=18800,
269+
description="EMNIST Balanced: 131,600 characters. 47 balanced classes.",
270+
version="0.1.1",
271+
),
272+
EMNISTConfig(
273+
name="letters",
274+
class_number=37,
275+
train_examples=88800,
276+
test_examples=14800,
277+
description="EMNIST Letters: 103,600 characters. 26 balanced classes.",
278+
version="0.1.1",
279+
),
280+
EMNISTConfig(
281+
name="digits",
282+
class_number=10,
283+
train_examples=240000,
284+
test_examples=40000,
285+
description="EMNIST Digits: 280,000 characters. 10 balanced classes.",
286+
version="0.1.1",
287+
),
288+
EMNISTConfig(
289+
name="mnist",
290+
class_number=10,
291+
train_examples=60000,
292+
test_examples=10000,
293+
description="EMNIST MNIST: 70,000 characters. 10 balanced classes.",
294+
version="0.1.1",
295+
),
296+
EMNISTConfig(
297+
name="test",
298+
class_number=62,
299+
train_examples=10,
300+
test_examples=2,
301+
description="EMNIST test data config.",
302+
version="0.1.1",
303+
),
304+
]
305+
306+
def _info(self):
307+
return tfds.core.DatasetInfo(
308+
builder=self,
309+
description=("The EMNIST dataset is a set of handwritten character digits"
310+
"derived from the NIST Special Database 19 and converted to"
311+
"a 28x28 pixel image format and dataset structure that directly"
312+
"matches the MNIST dataset."
313+
),
314+
features=tfds.features.FeaturesDict({
315+
"image": tfds.features.Image(shape=_MNIST_IMAGE_SHAPE),
316+
"label": tfds.features.ClassLabel(num_classes=self.builder_config.class_number),
317+
318+
}),
319+
supervised_keys=("image", "label"),
320+
urls=["https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"],
321+
citation=_EMNIST_CITATION,
322+
)
323+
324+
def _split_generators(self, dl_manager):
325+
326+
filenames = {
327+
"train_data": 'emnist-{}-train-images-idx3-ubyte'.format(self.builder_config.name),
328+
"train_labels": 'emnist-{}-train-labels-idx1-ubyte'.format(self.builder_config.name),
329+
"test_data": 'emnist-{}-test-images-idx3-ubyte'.format(self.builder_config.name),
330+
"test_labels": 'emnist-{}-test-labels-idx1-ubyte'.format(self.builder_config.name),
331+
}
332+
dir_name = dl_manager.manual_dir
333+
import os
334+
return [
335+
tfds.core.SplitGenerator(
336+
name=tfds.Split.TRAIN,
337+
num_shards=10,
338+
gen_kwargs=dict(
339+
num_examples=self.builder_config.train_examples,
340+
data_path=os.path.join(dir_name, filenames['train_data']),
341+
label_path=os.path.join(dir_name, filenames["train_labels"]),
342+
)
343+
344+
),
345+
346+
tfds.core.SplitGenerator(
347+
name=tfds.Split.TEST,
348+
num_shards=1,
349+
gen_kwargs=dict(
350+
num_examples=self.builder_config.test_examples,
351+
data_path=os.path.join(dir_name, filenames['test_data']),
352+
label_path=os.path.join(dir_name, filenames["test_labels"]),
353+
)
354+
)
355+
]
356+
357+
210358

211359

212360
def _extract_mnist_images(image_filepath, num_images):
@@ -226,8 +374,3 @@ def _extract_mnist_labels(labels_filepath, num_labels):
226374
buf = f.read(num_labels)
227375
labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
228376
return labels
229-
230-
231-
232-
# test file
233-
# and full test

0 commit comments

Comments
 (0)