Commit 4badde4

Improve usability of batch_size in as_dataset function.
1 parent fd27f0b commit 4badde4

4 files changed, +30 -4 lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 2 additions & 2 deletions
@@ -308,7 +308,7 @@ def download_and_prepare(self, download_dir=None, download_config=None):
   @api_utils.disallow_positional_args
   def as_dataset(self,
                  split=None,
-                 batch_size=1,
+                 batch_size=None,
                  shuffle_files=None,
                  as_supervised=False):
     """Constructs a `tf.data.Dataset`.
@@ -376,7 +376,7 @@ def _build_single_dataset(self, split, shuffle_files, batch_size,
       batch_size = self.info.splits.total_num_examples or sys.maxsize

     dataset = self._as_dataset(split=split, shuffle_files=shuffle_files)
-    if batch_size > 1:
+    if batch_size:
       # Use padded_batch so that features with unknown shape are supported.
       padded_shapes = self.info.features.shape
       dataset = dataset.padded_batch(batch_size, padded_shapes)
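
With batch_size=None as the new default, _build_single_dataset only pads and batches when a batch size is actually requested; otherwise the dataset yields individual examples. A minimal usage sketch of the new behavior (the "mnist" dataset name, the example shapes, and the tfds import alias are illustrative assumptions, not part of this commit):

    import tensorflow_datasets as tfds  # assumed import alias

    # Hypothetical example dataset; any registered builder behaves the same way.
    builder = tfds.builder("mnist")
    builder.download_and_prepare()

    # batch_size=None (new default): elements are single examples, so each
    # feature keeps its per-example shape, e.g. image -> (28, 28, 1).
    ds = builder.as_dataset(split=tfds.Split.TRAIN)

    # batch_size=32: padded_batch adds a leading batch dimension,
    # e.g. image -> (32, 28, 28, 1).
    ds_batched = builder.as_dataset(split=tfds.Split.TRAIN, batch_size=32)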

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 10 additions & 0 deletions
@@ -423,6 +423,16 @@ def test_with_batch_size(self):
     self.assertEqual(10, x2.shape[0])
     self.assertEqual(10, x3.shape[0])
     self.assertEqual(sum(range(30)), int(x1.sum() + x2.sum() + x3.sum()))
+
+    # By default batch_size is None and won't add a batch dimension
+    ds = self.builder.as_dataset(split=splits_lib.Split.TRAIN)
+    self.assertEqual(0, len(ds.output_shapes["x"]))
+    # Setting batch_size=1 will add an extra batch dimension
+    ds = self.builder.as_dataset(split=splits_lib.Split.TRAIN, batch_size=1)
+    self.assertEqual(1, len(ds.output_shapes["x"]))
+    # Setting batch_size=2 will add an extra batch dimension
+    ds = self.builder.as_dataset(split=splits_lib.Split.TRAIN, batch_size=2)
+    self.assertEqual(1, len(ds.output_shapes["x"]))

   @testing.run_in_graph_and_eager_modes()
   def test_supervised_keys(self):

tensorflow_datasets/core/registered.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ def builder(name, **builder_init_kwargs):
 def load(name,
          split=None,
          data_dir=None,
-         batch_size=1,
+         batch_size=None,
          download=True,
          as_supervised=False,
          with_info=False,
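
The load wrapper picks up the same default, so a plain tfds.load call now returns unbatched examples and forwards any explicit batch_size down to as_dataset. A short sketch (again assuming the tfds import alias and an illustrative "mnist" dataset name):

    import tensorflow_datasets as tfds  # assumed import alias

    # New default: no batch dimension is added.
    ds = tfds.load(name="mnist", split=tfds.Split.TRAIN)

    # Passing batch_size=1 restores the previous default behavior
    # (a leading batch dimension of size 1).
    ds_b1 = tfds.load(name="mnist", split=tfds.Split.TRAIN, batch_size=1)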

tensorflow_datasets/core/registered_test.py

Lines changed: 17 additions & 1 deletion
@@ -120,7 +120,7 @@ def test_load(self):
     self.assertFalse(builder.download_called)
     self.assertEqual(splits.Split.TEST,
                      builder.as_dataset_kwargs.pop("split"))
-    self.assertEqual(1, builder.as_dataset_kwargs.pop("batch_size"))
+    self.assertEqual(None, builder.as_dataset_kwargs.pop("batch_size"))
     self.assertFalse(builder.as_dataset_kwargs.pop("as_supervised"))
     self.assertEqual(builder.as_dataset_kwargs, as_dataset_kwargs)
     self.assertEqual(dict(data_dir=data_dir, k1=1), builder.kwargs)
@@ -130,6 +130,22 @@ def test_load(self):
         download=True, as_dataset_kwargs=as_dataset_kwargs)
     self.assertTrue(builder.as_dataset_called)
     self.assertTrue(builder.download_called)
+
+    # Tests for different batch_size
+    # By default batch_size=None
+    builder = registered.load(
+        name=name, split=splits.Split.TEST, data_dir=data_dir)
+    self.assertEqual(None, builder.as_dataset_kwargs.pop("batch_size"))
+    # Setting batch_size=1
+    builder = registered.load(
+        name=name, split=splits.Split.TEST, data_dir=data_dir,
+        batch_size=1)
+    self.assertEqual(1, builder.as_dataset_kwargs.pop("batch_size"))
+    # Setting batch_size=2
+    builder = registered.load(
+        name=name, split=splits.Split.TEST, data_dir=data_dir,
+        batch_size=2)
+    self.assertEqual(2, builder.as_dataset_kwargs.pop("batch_size"))

   def test_load_all_splits(self):
     name = "empty_dataset_builder"
