
Commit e0cad3d

Ryan Sepassi authored and copybara-github committed
Add in_memory option to load and as_dataset for small datasets
Defaults to loading small datasets (<1GB) in memory. Note that to benefit from this, tfds.load should be called just once and the dataset that's returned should be reused.

PiperOrigin-RevId: 249295830
1 parent 43176bb commit e0cad3d
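
To make the "call tfds.load just once and reuse the result" note above concrete, here is a minimal usage sketch (assumes eager execution; the "mnist" dataset name and epoch count are placeholders, not part of this commit):

import tensorflow_datasets as tfds

NUM_EPOCHS = 5  # placeholder
ds = tfds.load("mnist", split="train", in_memory=True)  # call load() only once
for _ in range(NUM_EPOCHS):
  for example in ds:  # reuse the same dataset object so the in-memory copy
    pass              # is built a single time and shared across epochs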

File tree

5 files changed: +69 -6 lines changed

tensorflow_datasets/core/dataset_builder.py
tensorflow_datasets/core/dataset_builder_test.py
tensorflow_datasets/core/dataset_utils.py
tensorflow_datasets/core/registered.py
tensorflow_datasets/core/registered_test.py


tensorflow_datasets/core/dataset_builder.py (+41 -4)

@@ -275,7 +275,8 @@ def as_dataset(self,
                  split=None,
                  batch_size=1,
                  shuffle_files=None,
-                 as_supervised=False):
+                 as_supervised=False,
+                 in_memory=None):
     """Constructs a `tf.data.Dataset`.
 
     Callers must pass arguments as keyword arguments.
@@ -297,6 +298,12 @@ def as_dataset(self,
         `builder.info.supervised_keys`. If `False`, the default,
         the returned `tf.data.Dataset` will have a dictionary with all the
         features.
+      in_memory: `bool`, if `True`, loads the dataset in memory which
+        increases iteration speeds. Note that if `True` and the dataset has
+        unknown dimensions, the features will be padded to the maximum
+        size across the dataset. By default (when `None`), will load the
+        dataset in memory if the size is <1GB and all feature dimensions are
+        statically known.
 
     Returns:
       `tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
@@ -322,12 +329,13 @@ def as_dataset(self,
         shuffle_files=shuffle_files,
         batch_size=batch_size,
         as_supervised=as_supervised,
+        in_memory=in_memory,
     )
     datasets = utils.map_nested(build_single_dataset, split, map_tuple=True)
     return datasets
 
   def _build_single_dataset(self, split, shuffle_files, batch_size,
-                            as_supervised):
+                            as_supervised, in_memory):
     """as_dataset for a single split."""
     if isinstance(split, six.string_types):
       split = splits_lib.Split(split)
@@ -341,10 +349,39 @@ def _build_single_dataset(self, split, shuffle_files, batch_size,
       batch_size = self.info.splits.total_num_examples or sys.maxsize
 
     dataset = self._as_dataset(split=split, shuffle_files=shuffle_files)
+
+    # If the dataset is small, load it in memory
+    # TODO(tfds): Expose and use the actual data size on disk and rm the manual
+    # name guards. size_in_bytes is the download size, which is misleading,
+    # particularly for datasets that use manual_dir as well as some downloads
+    # (wmt and diabetic_retinopathy_detection).
+    dataset_shape_is_fully_defined = (
+        dataset_utils.dataset_shape_is_fully_defined(dataset))
+    in_memory_default = (
+        self.info.size_in_bytes and
+        self.info.size_in_bytes <= 1e9 and
+        not self.name.startswith("wmt") and
+        not self.name.startswith("diabetic") and
+        dataset_shape_is_fully_defined)
+    in_memory = in_memory_default if in_memory is None else in_memory
+    if in_memory and not wants_full_dataset:
+      # TODO(tfds): Enable in_memory without padding features. May be able
+      # to do by using a requested version of tf.data.Dataset.cache that can
+      # persist a cache beyond iterator instances.
+      if not dataset_shape_is_fully_defined:
+        tf.logging.warning("Called in_memory=True on a dataset that does not "
+                           "have fully defined shapes. Note that features with "
+                           "variable length dimensions will be 0-padded to "
+                           "the maximum length across the dataset.")
+      # Use padded_batch so that features with unknown shape are supported.
+      full_bs = self.info.splits.total_num_examples or sys.maxsize
+      dataset = dataset.padded_batch(full_bs, dataset.output_shapes)
+      dataset = tf.data.Dataset.from_tensor_slices(
+          next(dataset_utils.as_numpy(dataset)))
+
     if batch_size > 1:
       # Use padded_batch so that features with unknown shape are supported.
-      padded_shapes = self.info.features.shape
-      dataset = dataset.padded_batch(batch_size, padded_shapes)
+      dataset = dataset.padded_batch(batch_size, dataset.output_shapes)
 
     if as_supervised:
       if not self.info.supervised_keys:
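
For context, the in-memory path added to _build_single_dataset above batches the whole split into one element, pulls it into host memory, and rebuilds a tf.data.Dataset from the resulting arrays. Below is a minimal standalone sketch of the same idea in TF 2.x eager style, not the TF 1.x code from this commit (the load_in_memory helper and toy dataset are illustrative assumptions):

import tensorflow as tf

def load_in_memory(ds, num_examples):
  """Hypothetical helper: materializes a small dataset as in-memory arrays."""
  # One giant batch holds every example (fine when element shapes are static);
  # .numpy() copies the tensors into host memory.
  full_batch = next(iter(ds.batch(num_examples)))
  arrays = tf.nest.map_structure(lambda t: t.numpy(), full_batch)
  # Rebuilding from the arrays means later epochs never touch the source files.
  return tf.data.Dataset.from_tensor_slices(arrays)

# Toy usage:
ds = tf.data.Dataset.from_tensor_slices({"x": tf.range(20), "y": tf.range(20)})
ds_mem = load_in_memory(ds, num_examples=20)
for batch in ds_mem.repeat(3).batch(4):
  pass  # iterates from RAM rather than re-reading the original source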

tensorflow_datasets/core/dataset_builder_test.py (+7 -0)

@@ -414,6 +414,13 @@ def tearDownClass(cls):
   def setUp(self):
     self.builder = DummyDatasetSharedGenerator(data_dir=self._tfds_tmp_dir)
 
+  @testing.run_in_graph_and_eager_modes()
+  def test_in_memory(self):
+    train_data = dataset_utils.as_numpy(
+        self.builder.as_dataset(split="train", in_memory=True))
+    train_data = [el for el in train_data]
+    self.assertEqual(20, len(train_data))
+
   @testing.run_in_graph_and_eager_modes()
   def test_all_splits(self):
     splits = dataset_utils.as_numpy(

tensorflow_datasets/core/dataset_utils.py (+12 -2)

@@ -163,8 +163,13 @@ def _eager_dataset_iterator(dataset):
     yield tf.nest.pack_sequence_as(item, flat)
 
 
-def _graph_dataset_iterator(ds_item, graph=None):
+def _graph_dataset_iterator(ds_iter, graph=None):
+  """Constructs a Python generator from a tf.data.Iterator."""
+  with utils.maybe_with_graph(graph, create_if_none=False):
+    init = ds_iter.initializer
+    ds_item = ds_iter.get_next()
   with utils.nogpu_session(graph) as sess:
+    sess.run(init)
     while True:
       try:
         yield sess.run(ds_item)
@@ -219,7 +224,7 @@ def as_numpy(dataset, graph=None):
   # First create iterators for datasets
   with utils.maybe_with_graph(graph, create_if_none=False):
     ds_iters = [
-        tf.compat.v1.data.make_one_shot_iterator(ds_el).get_next()
+        tf.compat.v1.data.make_initializable_iterator(ds_el)
         for ds_el in flat_ds if tf_compat.is_dataset(ds_el)
     ]
   ds_iters = [_graph_dataset_iterator(ds_iter, graph) for ds_iter in ds_iters]
@@ -240,3 +245,8 @@ def as_numpy(dataset, graph=None):
 
   # Nest
   return tf.nest.pack_sequence_as(nested_ds, flat_np)
+
+
+def dataset_shape_is_fully_defined(ds):
+  return all(
+      [ts.is_fully_defined() for ts in tf.nest.flatten(ds.output_shapes)])
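
To show what the new dataset_shape_is_fully_defined helper checks, here is a small sketch using the compat shape accessor rather than the output_shapes property used above (assumes TF 2.x eager; the dataset constructions are illustrative only):

import tensorflow as tf

def shapes_fully_defined(ds):
  # Same check as the helper above: every feature's static shape must be known.
  shapes = tf.compat.v1.data.get_output_shapes(ds)
  return all(ts.is_fully_defined() for ts in tf.nest.flatten(shapes))

fixed = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 3]))   # elements: (3,)
ragged = fixed.map(lambda x: tf.boolean_mask(x, x > 0))        # elements: (None,)
assert shapes_fully_defined(fixed)
assert not shapes_fully_defined(ragged)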

tensorflow_datasets/core/registered.py (+8 -0)

@@ -177,6 +177,7 @@ def load(name,
          split=None,
          data_dir=None,
          batch_size=1,
+         in_memory=None,
          download=True,
          as_supervised=False,
          with_info=False,
@@ -231,6 +232,12 @@ def load(name,
     batch_size: `int`, set to > 1 to get batches of examples. Note that
       variable length features will be 0-padded. If
       `batch_size=-1`, will return the full dataset as `tf.Tensor`s.
+    in_memory: `bool`, if `True`, loads the dataset in memory which
+      increases iteration speeds. Note that if `True` and the dataset has
+      unknown dimensions, the features will be padded to the maximum
+      size across the dataset. By default (when `None`), will load the
+      dataset in memory if the size is <1GB and all feature dimensions are
+      statically known.
     download: `bool` (optional), whether to call
       `tfds.core.DatasetBuilder.download_and_prepare`
       before calling `tf.DatasetBuilder.as_dataset`. If `False`, data is
@@ -290,6 +297,7 @@ def load(name,
   as_dataset_kwargs["split"] = split
   as_dataset_kwargs["as_supervised"] = as_supervised
   as_dataset_kwargs["batch_size"] = batch_size
+  as_dataset_kwargs["in_memory"] = in_memory
 
   ds = dbuilder.as_dataset(**as_dataset_kwargs)
   if with_info:
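
Since load only forwards in_memory through as_dataset_kwargs, the equivalent builder-level call is the sketch below (hedged; "mnist" is a stand-in dataset name, not something this commit touches):

import tensorflow_datasets as tfds

builder = tfds.builder("mnist")
builder.download_and_prepare()
# load() simply forwards in_memory via as_dataset_kwargs to this call:
ds = builder.as_dataset(split="train", in_memory=True)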

tensorflow_datasets/core/registered_test.py (+1 -0)

@@ -122,6 +122,7 @@ def test_load(self):
                      builder.as_dataset_kwargs.pop("split"))
     self.assertEqual(1, builder.as_dataset_kwargs.pop("batch_size"))
     self.assertFalse(builder.as_dataset_kwargs.pop("as_supervised"))
+    self.assertIsNone(builder.as_dataset_kwargs.pop("in_memory"))
     self.assertEqual(builder.as_dataset_kwargs, as_dataset_kwargs)
     self.assertEqual(dict(data_dir=data_dir, k1=1), builder.kwargs)
 