13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
15
16
- """MNIST and Fashion MNIST."""
16
+ """MNIST, Fashion MNIST, KMNIST and EMNIST ."""
17
17
18
18
from __future__ import absolute_import
19
19
from __future__ import division
23
23
import six .moves .urllib as urllib
24
24
import tensorflow as tf
25
25
26
+ from tensorflow_datasets .core import api_utils
26
27
import tensorflow_datasets .public_api as tfds
27
28
28
29
# MNIST constants
68
69
"""
69
70
70
71
71
- _K_MNIST_CITATION = """
72
+ _K_MNIST_CITATION = """\
72
73
@online{clanuwat2018deep,
73
74
author = {Tarin Clanuwat and Mikel Bober-Irizar and Asanobu Kitamoto and Alex Lamb and Kazuaki Yamamoto and David Ha},
74
75
title = {Deep Learning for Classical Japanese Literature},
77
78
eprintclass = {cs.CV},
78
79
eprinttype = {arXiv},
79
80
eprint = {cs.CV/1812.01718},
80
- }
81
+ }
82
+ """
83
+
84
+ _EMNIST_CITATION = """\
85
+ @article{cohen_afshar_tapson_schaik_2017,
86
+ title={EMNIST: Extending MNIST to handwritten letters},
87
+ DOI={10.1109/ijcnn.2017.7966217},
88
+ journal={2017 International Joint Conference on Neural Networks (IJCNN)},
89
+ author={Cohen, Gregory and Afshar, Saeed and Tapson, Jonathan and Schaik, Andre Van},
90
+ year={2017}
91
+ }
81
92
"""
82
93
83
94
class MNIST (tfds .core .GeneratorBasedBuilder ):
@@ -207,6 +218,143 @@ def _info(self):
207
218
citation = _K_MNIST_CITATION ,
208
219
)
209
220
221
+ class EMNISTConfig (tfds .core .BuilderConfig ):
222
+ """BuilderConfig for EMNIST CONFIG."""
223
+
224
+ @api_utils .disallow_positional_args
225
+ def __init__ (self , class_number , train_examples , test_examples , ** kwargs ):
226
+ """BuilderConfig for EMNIST class number.
227
+
228
+ Args:
229
+ class_number: There are six different splits provided in this dataset. And have
230
+ different class numbers.
231
+
232
+ train_examples, test_examples: So in these have different test and train character
233
+ numbers.
234
+
235
+ **kwargs: keyword arguments forwarded to super.
236
+ """
237
+ super (EMNISTConfig , self ).__init__ (** kwargs )
238
+ self .class_number = class_number
239
+ self .train_examples = train_examples
240
+ self .test_examples = test_examples
241
+
242
+
243
+ class EMNIST (MNIST ):
244
+
245
+ VERSION = tfds .core .Version ('1.0.0' )
246
+
247
+ BUILDER_CONFIGS = [
248
+ EMNISTConfig (
249
+ name = "byclass" ,
250
+ class_number = 62 ,
251
+ train_examples = 697932 ,
252
+ test_examples = 116323 ,
253
+ description = "EMNIST ByClass: 814,255 characters. 62 unbalanced classes." ,
254
+ version = "0.1.1" ,
255
+ ),
256
+ EMNISTConfig (
257
+ name = "bymerge" ,
258
+ class_number = 47 ,
259
+ train_examples = 697932 ,
260
+ test_examples = 116323 ,
261
+ description = "EMNIST ByMerge: 814,255 characters. 47 unbalanced classes." ,
262
+ version = "0.1.1" ,
263
+ ),
264
+ EMNISTConfig (
265
+ name = "balanced" ,
266
+ class_number = 47 ,
267
+ train_examples = 112800 ,
268
+ test_examples = 18800 ,
269
+ description = "EMNIST Balanced: 131,600 characters. 47 balanced classes." ,
270
+ version = "0.1.1" ,
271
+ ),
272
+ EMNISTConfig (
273
+ name = "letters" ,
274
+ class_number = 37 ,
275
+ train_examples = 88800 ,
276
+ test_examples = 14800 ,
277
+ description = "EMNIST Letters: 103,600 characters. 26 balanced classes." ,
278
+ version = "0.1.1" ,
279
+ ),
280
+ EMNISTConfig (
281
+ name = "digits" ,
282
+ class_number = 10 ,
283
+ train_examples = 240000 ,
284
+ test_examples = 40000 ,
285
+ description = "EMNIST Digits: 280,000 characters. 10 balanced classes." ,
286
+ version = "0.1.1" ,
287
+ ),
288
+ EMNISTConfig (
289
+ name = "mnist" ,
290
+ class_number = 10 ,
291
+ train_examples = 60000 ,
292
+ test_examples = 10000 ,
293
+ description = "EMNIST MNIST: 70,000 characters. 10 balanced classes." ,
294
+ version = "0.1.1" ,
295
+ ),
296
+ EMNISTConfig (
297
+ name = "test" ,
298
+ class_number = 62 ,
299
+ train_examples = 10 ,
300
+ test_examples = 2 ,
301
+ description = "EMNIST test data config." ,
302
+ version = "0.1.1" ,
303
+ ),
304
+ ]
305
+
306
+ def _info (self ):
307
+ return tfds .core .DatasetInfo (
308
+ builder = self ,
309
+ description = ("The EMNIST dataset is a set of handwritten character digits"
310
+ "derived from the NIST Special Database 19 and converted to"
311
+ "a 28x28 pixel image format and dataset structure that directly"
312
+ "matches the MNIST dataset."
313
+ ),
314
+ features = tfds .features .FeaturesDict ({
315
+ "image" : tfds .features .Image (shape = _MNIST_IMAGE_SHAPE ),
316
+ "label" : tfds .features .ClassLabel (num_classes = self .builder_config .class_number ),
317
+
318
+ }),
319
+ supervised_keys = ("image" , "label" ),
320
+ urls = ["https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" ],
321
+ citation = _EMNIST_CITATION ,
322
+ )
323
+
324
+ def _split_generators (self , dl_manager ):
325
+
326
+ filenames = {
327
+ "train_data" : 'emnist-{}-train-images-idx3-ubyte' .format (self .builder_config .name ),
328
+ "train_labels" : 'emnist-{}-train-labels-idx1-ubyte' .format (self .builder_config .name ),
329
+ "test_data" : 'emnist-{}-test-images-idx3-ubyte' .format (self .builder_config .name ),
330
+ "test_labels" : 'emnist-{}-test-labels-idx1-ubyte' .format (self .builder_config .name ),
331
+ }
332
+ dir_name = dl_manager .manual_dir
333
+ import os
334
+ return [
335
+ tfds .core .SplitGenerator (
336
+ name = tfds .Split .TRAIN ,
337
+ num_shards = 10 ,
338
+ gen_kwargs = dict (
339
+ num_examples = self .builder_config .train_examples ,
340
+ data_path = os .path .join (dir_name , filenames ['train_data' ]),
341
+ label_path = os .path .join (dir_name , filenames ["train_labels" ]),
342
+ )
343
+
344
+ ),
345
+
346
+ tfds .core .SplitGenerator (
347
+ name = tfds .Split .TEST ,
348
+ num_shards = 1 ,
349
+ gen_kwargs = dict (
350
+ num_examples = self .builder_config .test_examples ,
351
+ data_path = os .path .join (dir_name , filenames ['test_data' ]),
352
+ label_path = os .path .join (dir_name , filenames ["test_labels" ]),
353
+ )
354
+ )
355
+ ]
356
+
357
+
210
358
211
359
212
360
def _extract_mnist_images (image_filepath , num_images ):
@@ -226,8 +374,3 @@ def _extract_mnist_labels(labels_filepath, num_labels):
226
374
buf = f .read (num_labels )
227
375
labels = np .frombuffer (buf , dtype = np .uint8 ).astype (np .int64 )
228
376
return labels
229
-
230
-
231
-
232
- # test file
233
- # and full test
0 commit comments