@@ -312,8 +312,10 @@ def as_dataset(self,
                  split=None,
                  batch_size=None,
                  shuffle_files=None,
+                 decoders=None,
                  as_supervised=False,
                  in_memory=None):
+    # pylint: disable=line-too-long
     """Constructs a `tf.data.Dataset`.
 
     Callers must pass arguments as keyword arguments.
@@ -330,6 +332,9 @@ def as_dataset(self,
         `tf.data.Dataset`.
       shuffle_files: `bool`, whether to shuffle the input files.
         Defaults to `True` if `split == tfds.Split.TRAIN` and `False` otherwise.
+      decoders: Nested dict of `Decoder` objects that allow customizing the
+        decoding. The structure should match the feature structure, but only
+        customized feature keys need to be present.
       as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
         will have a 2-tuple structure `(input, label)` according to
         `builder.info.supervised_keys`. If `False`, the default,
@@ -347,6 +352,7 @@ def as_dataset(self,
       If `batch_size` is -1, will return feature dictionaries containing
       the entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
     """
+    # pylint: enable=line-too-long
     logging.info("Constructing tf.data.Dataset for split %s, from %s",
                  split, self._data_dir)
     if not tf.io.gfile.exists(self._data_dir):
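A minimal caller-side sketch of the new `decoders` argument, assuming the `tfds.decode.SkipDecoding` decoder from the accompanying decode module; the dataset name "my_dataset" is a placeholder, and only the keys being customized appear in the nested dict:

    import tensorflow_datasets as tfds

    builder = tfds.builder("my_dataset")  # placeholder dataset name
    builder.download_and_prepare()

    # Skip the default image decoding; every other feature keeps its default.
    ds = builder.as_dataset(
        split=tfds.Split.TRAIN,
        decoders={"image": tfds.decode.SkipDecoding()},  # assumed API
    )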
@@ -365,14 +371,21 @@
         self._build_single_dataset,
         shuffle_files=shuffle_files,
         batch_size=batch_size,
+        decoders=decoders,
         as_supervised=as_supervised,
         in_memory=in_memory,
     )
     datasets = utils.map_nested(build_single_dataset, split, map_tuple=True)
     return datasets
 
-  def _build_single_dataset(self, split, shuffle_files, batch_size,
-                            as_supervised, in_memory):
+  def _build_single_dataset(
+      self,
+      split,
+      shuffle_files,
+      batch_size,
+      decoders,
+      as_supervised,
+      in_memory):
     """as_dataset for a single split."""
     if isinstance(split, six.string_types):
       split = splits_lib.Split(split)
@@ -424,13 +437,15 @@ def _build_single_dataset(self, split, shuffle_files, batch_size,
       # If using in_memory, escape all device contexts so we can load the data
      # with a local Session.
       with tf.device(None):
-        dataset = self._as_dataset(split=split, shuffle_files=shuffle_files)
+        dataset = self._as_dataset(
+            split=split, shuffle_files=shuffle_files, decoders=decoders)
         # Use padded_batch so that features with unknown shape are supported.
         dataset = dataset.padded_batch(full_bs, dataset.output_shapes)
         dataset = tf.data.Dataset.from_tensor_slices(
             next(dataset_utils.as_numpy(dataset)))
     else:
-      dataset = self._as_dataset(split=split, shuffle_files=shuffle_files)
+      dataset = self._as_dataset(
+          split=split, shuffle_files=shuffle_files, decoders=decoders)
 
     if batch_size:
       # Use padded_batch so that features with unknown shape are supported.
@@ -567,16 +582,18 @@ def _download_and_prepare(self, dl_manager, download_config=None):
     raise NotImplementedError
 
   @abc.abstractmethod
-  def _as_dataset(self, split, shuffle_files=None):
+  def _as_dataset(self, split, decoders=None, shuffle_files=None):
     """Constructs a `tf.data.Dataset`.
 
     This is the internal implementation to override, called when the user
     calls `as_dataset`. It should read the pre-processed dataset files and
     generate the `tf.data.Dataset` object.
 
     Args:
-      split (`tfds.Split`): which subset of the data to read.
-      shuffle_files (bool): whether to shuffle the input files. Optional,
+      split: `tfds.Split`, which subset of the data to read.
+      decoders: Nested structure of `Decoder` objects to customize the dataset
+        decoding.
+      shuffle_files: `bool`, whether to shuffle the input files. Optional,
         defaults to `True` if `split == tfds.Split.TRAIN` and `False` otherwise.
 
     Returns:
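Since `decoders` mirrors the (possibly nested) feature structure, a hedged sketch of a nested value; the feature names below are illustrative only:

    # For features like {"image": ..., "objects": {"bbox": ...}}, customize
    # only the keys that need it; untouched features decode as before.
    decoders = {
        "objects": {
            "bbox": tfds.decode.SkipDecoding(),  # assumed API
        },
    }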
@@ -759,7 +776,12 @@ def _download_and_prepare(self, dl_manager, **prepare_split_kwargs):
     # Update the info object with the splits.
     self.info.update_splits_if_different(split_dict)
 
-  def _as_dataset(self, split=splits_lib.Split.TRAIN, shuffle_files=False):
+  def _as_dataset(
+      self,
+      split=splits_lib.Split.TRAIN,
+      decoders=None,
+      shuffle_files=False):
+
     if self.version.implements(utils.Experiment.S3):
       dataset = self._tfrecords_reader.read(
           self.name, split, self.info.splits.values(), shuffle_files)
@@ -780,9 +802,11 @@ def _as_dataset(self, split=splits_lib.Split.TRAIN, shuffle_files=False):
         dataset_from_file_fn=self._file_format_adapter.dataset_from_filename,
         shuffle_files=shuffle_files,
     )
+
+    decode_fn = functools.partial(
+        self.info.features.decode_example, decoders=decoders)
     dataset = dataset.map(
-        self.info.features.decode_example,
-        num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
     return dataset
 
   def _slice_split_info_to_instruction_dicts(self, list_sliced_split_info):
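The mapping change binds `decoders` into the map callable with `functools.partial`, so `Dataset.map` still sees a one-argument function. A self-contained sketch of the same pattern, using a simplified, hypothetical stand-in for `decode_example`:

    import functools

    # Simplified stand-in for `FeaturesDict.decode_example` (hypothetical).
    def decode_example(example, decoders=None):
      decoders = decoders or {}
      # Apply a per-key override where given, identity otherwise.
      return {k: decoders.get(k, lambda v: v)(v) for k, v in example.items()}

    decode_fn = functools.partial(
        decode_example, decoders={"label": lambda v: v + 1})
    print(decode_fn({"image": 0, "label": 1}))  # {'image': 0, 'label': 2}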
0 commit comments