@@ -275,7 +275,8 @@ def as_dataset(self,
                 split=None,
                 batch_size=1,
                 shuffle_files=None,
-                 as_supervised=False):
+                 as_supervised=False,
+                 in_memory=None):
     """Constructs a `tf.data.Dataset`.
 
     Callers must pass arguments as keyword arguments.
@@ -297,6 +298,12 @@ def as_dataset(self,
         `builder.info.supervised_keys`. If `False`, the default,
         the returned `tf.data.Dataset` will have a dictionary with all the
         features.
+      in_memory: `bool`, if `True`, loads the dataset in memory which
+        increases iteration speeds. Note that if `True` and the dataset has
+        unknown dimensions, the features will be padded to the maximum
+        size across the dataset. By default (when `None`), will load the
+        dataset in memory if the size is <1GB and all feature dimensions are
+        statically known.
 
     Returns:
       `tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
@@ -322,12 +329,13 @@ def as_dataset(self,
         shuffle_files=shuffle_files,
         batch_size=batch_size,
         as_supervised=as_supervised,
+        in_memory=in_memory,
     )
     datasets = utils.map_nested(build_single_dataset, split, map_tuple=True)
     return datasets
 
   def _build_single_dataset(self, split, shuffle_files, batch_size,
-                            as_supervised):
+                            as_supervised, in_memory):
     """as_dataset for a single split."""
     if isinstance(split, six.string_types):
       split = splits_lib.Split(split)
@@ -341,10 +349,39 @@ def _build_single_dataset(self, split, shuffle_files, batch_size,
       batch_size = self.info.splits.total_num_examples or sys.maxsize
 
     dataset = self._as_dataset(split=split, shuffle_files=shuffle_files)
+
+    # If the dataset is small, load it in memory
+    # TODO(tfds): Expose and use the actual data size on disk and rm the manual
+    # name guards. size_in_bytes is the download size, which is misleading,
+    # particularly for datasets that use manual_dir as well as some downloads
+    # (wmt and diabetic_retinopathy_detection).
+    dataset_shape_is_fully_defined = (
+        dataset_utils.dataset_shape_is_fully_defined(dataset))
+    in_memory_default = (
+        self.info.size_in_bytes and
+        self.info.size_in_bytes <= 1e9 and
+        not self.name.startswith("wmt") and
+        not self.name.startswith("diabetic") and
+        dataset_shape_is_fully_defined)
+    in_memory = in_memory_default if in_memory is None else in_memory
+    if in_memory and not wants_full_dataset:
+      # TODO(tfds): Enable in_memory without padding features. May be able
+      # to do by using a requested version of tf.data.Dataset.cache that can
+      # persist a cache beyond iterator instances.
+      if not dataset_shape_is_fully_defined:
+        tf.logging.warning("Called in_memory=True on a dataset that does not "
+                           "have fully defined shapes. Note that features with "
+                           "variable length dimensions will be 0-padded to "
+                           "the maximum length across the dataset.")
+      # Use padded_batch so that features with unknown shape are supported.
+      full_bs = self.info.splits.total_num_examples or sys.maxsize
+      dataset = dataset.padded_batch(full_bs, dataset.output_shapes)
+      dataset = tf.data.Dataset.from_tensor_slices(
+          next(dataset_utils.as_numpy(dataset)))
+
     if batch_size > 1:
       # Use padded_batch so that features with unknown shape are supported.
-      padded_shapes = self.info.features.shape
-      dataset = dataset.padded_batch(batch_size, padded_shapes)
+      dataset = dataset.padded_batch(batch_size, dataset.output_shapes)
 
     if as_supervised:
       if not self.info.supervised_keys:
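For reference, a minimal usage sketch of the new flag from the caller's side (not part of the change itself; the dataset name "mnist" is only illustrative, and any prepared builder is assumed to behave the same way):

import tensorflow_datasets as tfds

# Illustrative dataset; any already-prepared builder works the same way.
builder = tfds.builder("mnist")
builder.download_and_prepare()

# Explicitly load the split into memory. When in_memory is left as None,
# small datasets (<1GB) with statically known feature shapes get this
# behaviour by default.
ds = builder.as_dataset(split="train", in_memory=True)

# ds is still a regular tf.data.Dataset, now backed by in-memory tensor
# slices, so repeated epochs avoid re-reading the record files on disk.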