Commit 80ba498

Merge pull request #651 from ChanchalKumarMaji:improve-usability-1
PiperOrigin-RevId: 252517354
2 parents: 288f00c + 4badde4

File tree

4 files changed (+29, -10 lines)


tensorflow_datasets/core/dataset_builder.py

Lines changed: 5 additions & 6 deletions

@@ -308,7 +308,7 @@ def download_and_prepare(self, download_dir=None, download_config=None):
   @api_utils.disallow_positional_args
   def as_dataset(self,
                  split=None,
-                 batch_size=1,
+                 batch_size=None,
                  shuffle_files=None,
                  as_supervised=False):
     """Constructs a `tf.data.Dataset`.
@@ -320,8 +320,8 @@ def as_dataset(self,
         (default), returns all splits in a dict
         `<key: tfds.Split, value: tf.data.Dataset>`.
       batch_size: `int`, batch size. Note that variable-length features will
-        be 0-padded if `batch_size > 1`. Users that want more custom behavior
-        should use `batch_size=1` and use the `tf.data` API to construct a
+        be 0-padded if `batch_size` is set. Users that want more custom behavior
+        should use `batch_size=None` and use the `tf.data` API to construct a
         custom pipeline. If `batch_size == -1`, will return feature
         dictionaries of the whole dataset with `tf.Tensor`s instead of a
         `tf.data.Dataset`.
@@ -376,10 +376,9 @@ def _build_single_dataset(self, split, shuffle_files, batch_size,
       batch_size = self.info.splits.total_num_examples or sys.maxsize

     dataset = self._as_dataset(split=split, shuffle_files=shuffle_files)
-    if batch_size > 1:
+    if batch_size:
       # Use padded_batch so that features with unknown shape are supported.
-      padded_shapes = self.info.features.shape
-      dataset = dataset.padded_batch(batch_size, padded_shapes)
+      dataset = dataset.padded_batch(batch_size, dataset.output_shapes)

     if as_supervised:
       if not self.info.supervised_keys:
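
In practice, the new default means `as_dataset` yields individual examples unless the caller opts into batching. A minimal sketch of the behavior after this change (using `mnist` purely as an illustrative dataset name, and assuming TFDS is installed and the data has been prepared):

import tensorflow_datasets as tfds

builder = tfds.builder("mnist")  # "mnist" is only an illustrative choice
builder.download_and_prepare()

# New default, batch_size=None: elements are single examples, e.g.
# {"image": (28, 28, 1), "label": ()} -- no leading batch dimension.
ds = builder.as_dataset(split=tfds.Split.TRAIN)

# Explicit batch_size: examples go through padded_batch, so each element
# gains a leading batch dimension: {"image": (32, 28, 28, 1), "label": (32,)}.
ds_batched = builder.as_dataset(split=tfds.Split.TRAIN, batch_size=32)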

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 11 additions & 1 deletion

@@ -228,7 +228,7 @@ def test_with_configs(self):
     splits_list = [splits_lib.Split.TRAIN, splits_lib.Split.TEST]
     for builder, incr in [(builder1, 1), (builder2, 2)]:
       train_data, test_data = [  # pylint: disable=g-complex-comprehension
-          [el["x"] for el in
+          [el["x"] for el in  # pylint: disable=g-complex-comprehension
           dataset_utils.as_numpy(builder.as_dataset(split=split))]
           for split in splits_list
       ]
@@ -424,6 +424,16 @@ def test_with_batch_size(self):
     self.assertEqual(10, x3.shape[0])
     self.assertEqual(sum(range(30)), int(x1.sum() + x2.sum() + x3.sum()))

+    # By default batch_size is None and won't add a batch dimension
+    ds = self.builder.as_dataset(split=splits_lib.Split.TRAIN)
+    self.assertEqual(0, len(ds.output_shapes["x"]))
+    # Setting batch_size=1 will add an extra batch dimension
+    ds = self.builder.as_dataset(split=splits_lib.Split.TRAIN, batch_size=1)
+    self.assertEqual(1, len(ds.output_shapes["x"]))
+    # Setting batch_size=2 will add an extra batch dimension
+    ds = self.builder.as_dataset(split=splits_lib.Split.TRAIN, batch_size=2)
+    self.assertEqual(1, len(ds.output_shapes["x"]))
+
   @testing.run_in_graph_and_eager_modes()
   def test_supervised_keys(self):
     x, _ = dataset_utils.as_numpy(self.builder.as_dataset(
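
The `padded_batch` call the builder now relies on is why features of unknown (variable) length can still be batched: each batch is 0-padded to the longest element it contains. A standalone `tf.data` sketch of that mechanism (TF 1.x-era `output_types`/`output_shapes` API, matching the `output_shapes` usage in this commit; the toy sequences are made up for illustration):

import tensorflow as tf

# Toy variable-length feature: three sequences of different lengths.
ds = tf.data.Dataset.from_generator(
    lambda: iter([[1], [2, 3], [4, 5, 6]]),
    output_types=tf.int32,
    output_shapes=tf.TensorShape([None]))

# padded_batch 0-pads every element in a batch to that batch's longest
# sequence, which is how features with unknown shape stay batchable.
batched = ds.padded_batch(batch_size=3, padded_shapes=[None])
for batch in batched:  # eager mode
  print(batch.numpy())
# -> [[1 0 0]
#     [2 3 0]
#     [4 5 6]]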

tensorflow_datasets/core/registered.py

Lines changed: 2 additions & 2 deletions

@@ -176,7 +176,7 @@ def builder(name, **builder_init_kwargs):
 def load(name,
          split=None,
          data_dir=None,
-         batch_size=1,
+         batch_size=None,
          download=True,
          as_supervised=False,
          with_info=False,
@@ -228,7 +228,7 @@ def load(name,
       `tfds.Split.TEST`).
     data_dir: `str` (optional), directory to read/write data.
       Defaults to "~/tensorflow_datasets".
-    batch_size: `int`, set to > 1 to get batches of examples. Note that
+    batch_size: `int`, if set, add a batch dimension to examples. Note that
       variable length features will be 0-padded. If
       `batch_size=-1`, will return the full dataset as `tf.Tensor`s.
     download: `bool` (optional), whether to call
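
The same default flows through `tfds.load`. A short sketch of the call-site difference (again treating `mnist` as a stand-in dataset name):

import tensorflow_datasets as tfds

# After this change: unbatched examples by default.
ds = tfds.load("mnist", split=tfds.Split.TRAIN)

# Opt in to batching; variable-length features are 0-padded.
ds = tfds.load("mnist", split=tfds.Split.TRAIN, batch_size=32)

# batch_size=-1 still returns the full split as `tf.Tensor`s.
full = tfds.load("mnist", split=tfds.Split.TRAIN, batch_size=-1)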

tensorflow_datasets/core/registered_test.py

Lines changed: 11 additions & 1 deletion

@@ -120,7 +120,7 @@ def test_load(self):
     self.assertFalse(builder.download_called)
     self.assertEqual(splits.Split.TEST,
                      builder.as_dataset_kwargs.pop("split"))
-    self.assertEqual(1, builder.as_dataset_kwargs.pop("batch_size"))
+    self.assertEqual(None, builder.as_dataset_kwargs.pop("batch_size"))
     self.assertFalse(builder.as_dataset_kwargs.pop("as_supervised"))
     self.assertEqual(builder.as_dataset_kwargs, as_dataset_kwargs)
     self.assertEqual(dict(data_dir=data_dir, k1=1), builder.kwargs)
@@ -131,6 +131,16 @@ def test_load(self):
     self.assertTrue(builder.as_dataset_called)
     self.assertTrue(builder.download_called)

+    # Tests for different batch_size
+    # By default batch_size=None
+    builder = registered.load(
+        name=name, split=splits.Split.TEST, data_dir=data_dir)
+    self.assertEqual(None, builder.as_dataset_kwargs.pop("batch_size"))
+    # Setting batch_size=1
+    builder = registered.load(
+        name=name, split=splits.Split.TEST, data_dir=data_dir,
+        batch_size=1)
+
   def test_load_all_splits(self):
     name = "empty_dataset_builder"
     # EmptyDatasetBuilder returns self from as_dataset
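
One migration consequence worth noting: callers that relied on the old default of `batch_size=1` now receive unbatched examples and must pass the value explicitly to keep the old element shape. A sketch of hypothetical caller code, not part of the commit:

import tensorflow_datasets as tfds

# Before this commit, a bare load() added a leading size-1 batch dimension.
# To preserve that behavior under the new default, request it explicitly:
ds = tfds.load("mnist", split=tfds.Split.TRAIN, batch_size=1)
# Elements now look like {"image": (1, 28, 28, 1), "label": (1,)}, as before.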
