13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
15
16
- """MNIST and Fashion MNIST."""
16
+ """MNIST, Fashion MNIST, KMNIST and EMNIST ."""
17
17
18
18
from __future__ import absolute_import
19
19
from __future__ import division
20
20
from __future__ import print_function
21
21
22
+ import os
22
23
import numpy as np
23
24
import six .moves .urllib as urllib
24
25
import tensorflow as tf
25
26
27
+ from tensorflow_datasets .core import api_utils
26
28
import tensorflow_datasets .public_api as tfds
27
29
28
30
# MNIST constants
36
38
_TRAIN_EXAMPLES = 60000
37
39
_TEST_EXAMPLES = 10000
38
40
39
-
40
41
_MNIST_CITATION = """\
41
42
@article{lecun2010mnist,
42
43
title={MNIST handwritten digit database},
47
48
}
48
49
"""
49
50
50
-
51
51
_FASHION_MNIST_CITATION = """\
52
52
@article{DBLP:journals/corr/abs-1708-07747,
53
53
author = {Han Xiao and
67
67
}
68
68
"""
69
69
70
-
71
70
_K_MNIST_CITATION = """\
72
- @article{DBLP:journals/corr/abs-1812-01718,
73
- author = {Tarin Clanuwat and
74
- Mikel Bober{-}Irizar and
75
- Asanobu Kitamoto and
76
- Alex Lamb and
77
- Kazuaki Yamamoto and
78
- David Ha},
79
- title = {Deep Learning for Classical Japanese Literature},
80
- journal = {CoRR},
81
- volume = {abs/1812.01718},
82
- year = {2018},
83
- url = {http://arxiv.org/abs/1812.01718},
84
- archivePrefix = {arXiv},
85
- eprint = {1812.01718},
86
- timestamp = {Tue, 01 Jan 2019 15:01:25 +0100},
87
- biburl = {https://dblp.org/rec/bib/journals/corr/abs-1812-01718},
88
- bibsource = {dblp computer science bibliography, https://dblp.org}
71
+ @online{clanuwat2018deep,
72
+ author = {Tarin Clanuwat and Mikel Bober-Irizar and Asanobu Kitamoto and Alex Lamb and Kazuaki Yamamoto and David Ha},
73
+ title = {Deep Learning for Classical Japanese Literature},
74
+ date = {2018-12-03},
75
+ year = {2018},
76
+ eprintclass = {cs.CV},
77
+ eprinttype = {arXiv},
78
+ eprint = {cs.CV/1812.01718},
79
+ }
80
+ """
81
+
82
+ _EMNIST_CITATION = """\
83
+ @article{cohen_afshar_tapson_schaik_2017,
84
+ title={EMNIST: Extending MNIST to handwritten letters},
85
+ DOI={10.1109/ijcnn.2017.7966217},
86
+ journal={2017 International Joint Conference on Neural Networks (IJCNN)},
87
+ author={Cohen, Gregory and Afshar, Saeed and Tapson, Jonathan and Schaik, Andre Van},
88
+ year={2017}
89
89
}
90
90
"""
91
91
@@ -118,9 +118,8 @@ def _split_generators(self, dl_manager):
118
118
"test_data" : _MNIST_TEST_DATA_FILENAME ,
119
119
"test_labels" : _MNIST_TEST_LABELS_FILENAME ,
120
120
}
121
- mnist_files = dl_manager .download_and_extract ({
122
- k : urllib .parse .urljoin (self .URL , v ) for k , v in filenames .items ()
123
- })
121
+ mnist_files = dl_manager .download_and_extract (
122
+ {k : urllib .parse .urljoin (self .URL , v ) for k , v in filenames .items ()})
124
123
125
124
# MNIST provides TRAIN and TEST splits, not a VALIDATION split, so we only
126
125
# write the TRAIN and TEST splits to disk.
@@ -181,11 +180,13 @@ def _info(self):
181
180
"grayscale image, associated with a label from 10 "
182
181
"classes." ),
183
182
features = tfds .features .FeaturesDict ({
184
- "image" : tfds .features .Image (shape = _MNIST_IMAGE_SHAPE ),
185
- "label" : tfds .features .ClassLabel (names = [
186
- "T-shirt/top" , "Trouser" , "Pullover" , "Dress" , "Coat" ,
187
- "Sandal" , "Shirt" , "Sneaker" , "Bag" , "Ankle boot"
188
- ]),
183
+ "image" :
184
+ tfds .features .Image (shape = _MNIST_IMAGE_SHAPE ),
185
+ "label" :
186
+ tfds .features .ClassLabel (names = [
187
+ "T-shirt/top" , "Trouser" , "Pullover" , "Dress" , "Coat" ,
188
+ "Sandal" , "Shirt" , "Sneaker" , "Bag" , "Ankle boot"
189
+ ]),
189
190
}),
190
191
supervised_keys = ("image" , "label" ),
191
192
urls = ["https://github.com/zalandoresearch/fashion-mnist" ],
@@ -208,17 +209,182 @@ def _info(self):
208
209
"character to represent each of the 10 rows of Hiragana "
209
210
"when creating Kuzushiji-MNIST." ),
210
211
features = tfds .features .FeaturesDict ({
211
- "image" : tfds .features .Image (shape = _MNIST_IMAGE_SHAPE ),
212
- "label" : tfds .features .ClassLabel (names = [
213
- "o" , "ki" , "su" , "tsu" , "na" , "ha" , "ma" , "ya" , "re" , "wo"
214
- ]),
212
+ "image" :
213
+ tfds .features .Image (shape = _MNIST_IMAGE_SHAPE ),
214
+ "label" :
215
+ tfds .features .ClassLabel (names = [
216
+ "o" , "ki" , "su" , "tsu" , "na" , "ha" , "ma" , "ya" , "re" , "wo"
217
+ ]),
215
218
}),
216
219
supervised_keys = ("image" , "label" ),
217
220
urls = ["http://codh.rois.ac.jp/kmnist/index.html.en" ],
218
221
citation = _K_MNIST_CITATION ,
219
222
)
220
223
221
224
225
+ class EMNISTConfig (tfds .core .BuilderConfig ):
226
+ """BuilderConfig for EMNIST CONFIG."""
227
+
228
+ @api_utils .disallow_positional_args
229
+ def __init__ (self , class_number , train_examples , test_examples , ** kwargs ):
230
+ """BuilderConfig for EMNIST class number.
231
+
232
+ Args:
233
+ class_number: There are six different splits provided in this dataset. And
234
+ have different class numbers.
235
+ train_examples: number of train examples
236
+ test_examples: number of test examples
237
+ **kwargs: keyword arguments forwarded to super.
238
+ """
239
+ super (EMNISTConfig , self ).__init__ (** kwargs )
240
+ self .class_number = class_number
241
+ self .train_examples = train_examples
242
+ self .test_examples = test_examples
243
+
244
+
245
+ class EMNIST (MNIST ):
246
+ """Emnist dataset."""
247
+
248
+ VERSION = tfds .core .Version ("1.0.1" )
249
+
250
+ BUILDER_CONFIGS = [
251
+ EMNISTConfig (
252
+ name = "byclass" ,
253
+ class_number = 62 ,
254
+ train_examples = 697932 ,
255
+ test_examples = 116323 ,
256
+ description = "EMNIST ByClass: 814,255 characters. 62 unbalanced classes." ,
257
+ version = "1.0.1" ,
258
+ ),
259
+ EMNISTConfig (
260
+ name = "bymerge" ,
261
+ class_number = 47 ,
262
+ train_examples = 697932 ,
263
+ test_examples = 116323 ,
264
+ description = "EMNIST ByMerge: 814,255 characters. 47 unbalanced classes." ,
265
+ version = "1.0.1" ,
266
+ ),
267
+ EMNISTConfig (
268
+ name = "balanced" ,
269
+ class_number = 47 ,
270
+ train_examples = 112800 ,
271
+ test_examples = 18800 ,
272
+ description = "EMNIST Balanced: 131,600 characters. 47 balanced classes." ,
273
+ version = "1.0.1" ,
274
+ ),
275
+ EMNISTConfig (
276
+ name = "letters" ,
277
+ class_number = 37 ,
278
+ train_examples = 88800 ,
279
+ test_examples = 14800 ,
280
+ description = "EMNIST Letters: 103,600 characters. 26 balanced classes." ,
281
+ version = "1.0.1" ,
282
+ ),
283
+ EMNISTConfig (
284
+ name = "digits" ,
285
+ class_number = 10 ,
286
+ train_examples = 240000 ,
287
+ test_examples = 40000 ,
288
+ description = "EMNIST Digits: 280,000 characters. 10 balanced classes." ,
289
+ version = "1.0.1" ,
290
+ ),
291
+ EMNISTConfig (
292
+ name = "mnist" ,
293
+ class_number = 10 ,
294
+ train_examples = 60000 ,
295
+ test_examples = 10000 ,
296
+ description = "EMNIST MNIST: 70,000 characters. 10 balanced classes." ,
297
+ version = "1.0.1" ,
298
+ ),
299
+ EMNISTConfig (
300
+ name = "test" ,
301
+ class_number = 62 ,
302
+ train_examples = 10 ,
303
+ test_examples = 2 ,
304
+ description = "EMNIST test data config." ,
305
+ version = "1.0.1" ,
306
+ ),
307
+ ]
308
+
309
+ def _info (self ):
310
+ return tfds .core .DatasetInfo (
311
+ builder = self ,
312
+ description = (
313
+ "The EMNIST dataset is a set of handwritten character digits"
314
+ "derived from the NIST Special Database 19 and converted to"
315
+ "a 28x28 pixel image format and dataset structure that directly"
316
+ "matches the MNIST dataset." ),
317
+ features = tfds .features .FeaturesDict ({
318
+ "image" :
319
+ tfds .features .Image (shape = _MNIST_IMAGE_SHAPE ),
320
+ "label" :
321
+ tfds .features .ClassLabel (
322
+ num_classes = self .builder_config .class_number ),
323
+ }),
324
+ supervised_keys = ("image" , "label" ),
325
+ urls = ["https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" ],
326
+ citation = _EMNIST_CITATION ,
327
+ )
328
+
329
+ def _split_generators (self , dl_manager ):
330
+
331
+ filenames = {
332
+ "train_data" :
333
+ "emnist-{}-train-images-idx3-ubyte" .format (
334
+ self .builder_config .name ),
335
+ "train_labels" :
336
+ "emnist-{}-train-labels-idx1-ubyte" .format (
337
+ self .builder_config .name ),
338
+ "test_data" :
339
+ "emnist-{}-test-images-idx3-ubyte" .format (self .builder_config .name ),
340
+ "test_labels" :
341
+ "emnist-{}-test-labels-idx1-ubyte" .format (self .builder_config .name ),
342
+ }
343
+
344
+ dir_name = dl_manager .manual_dir
345
+
346
+ if not tf .io .gfile .exists (os .path .join (dir_name , filenames ["train_data" ])):
347
+ # The current tfds.core.download_manager is unable to
348
+ # extract multiple and nested files.
349
+ # We'll add soon! (Issue 234)
350
+ msg = ("You must download and extract the dataset files manually and "
351
+ "place them in : " )
352
+ msg += dl_manager .manual_dir
353
+ msg += """File tree must be like this :\n
354
+ .
355
+ | -- emnist
356
+ | |-- emnist-byclass-train-images-idx3-ubyte
357
+ | |-- emnist-byclass-train-labels-idx3-ubyte
358
+ | |-- emnist-byclass-test-images-idx3-ubyte
359
+ | |-- emnist-byclass-test-labels-idx3-ubyte
360
+ | |-- emnist-bymerge-train-images-idx3-ubyte
361
+ | |-- emnist-bymerge-train-labels-idx3-ubyte
362
+ | |-- emnist-bymerge-test-images-idx3-ubyte
363
+ | |-- emnist-bymerge-test-labels-idx3-ubyte
364
+ | |-- .......
365
+ """
366
+ raise Exception (msg .replace (" " , "" ))
367
+
368
+ return [
369
+ tfds .core .SplitGenerator (
370
+ name = tfds .Split .TRAIN ,
371
+ num_shards = 10 ,
372
+ gen_kwargs = dict (
373
+ num_examples = self .builder_config .train_examples ,
374
+ data_path = os .path .join (dir_name , filenames ["train_data" ]),
375
+ label_path = os .path .join (dir_name , filenames ["train_labels" ]),
376
+ )),
377
+ tfds .core .SplitGenerator (
378
+ name = tfds .Split .TEST ,
379
+ num_shards = 1 ,
380
+ gen_kwargs = dict (
381
+ num_examples = self .builder_config .test_examples ,
382
+ data_path = os .path .join (dir_name , filenames ["test_data" ]),
383
+ label_path = os .path .join (dir_name , filenames ["test_labels" ]),
384
+ ))
385
+ ]
386
+
387
+
222
388
def _extract_mnist_images (image_filepath , num_images ):
223
389
with tf .io .gfile .GFile (image_filepath , "rb" ) as f :
224
390
f .read (16 ) # header
0 commit comments