
Commit 41fbb33

pierrot0 authored and copybara-github committed
make S3 experiment default to True (Issue #737).
PiperOrigin-RevId: 257561767
1 parent 9986589 commit 41fbb33

92 files changed: +460 −410 lines (large commits hide some content by default; only a subset of the 92 diffs is shown below).
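The change itself is mechanical: `tfds.core.Version` now enables the S3 experiment (the new shuffling, sharding and slicing mechanism) unless a builder explicitly opts out. As a rough sketch of what the new default means (the version strings here are illustrative, not from this commit):

```python
import tensorflow_datasets as tfds

# After this commit, these two declarations are equivalent:
v_default = tfds.core.Version("1.0.0")
v_s3_on = tfds.core.Version("1.0.0",
                            experiments={tfds.core.Experiment.S3: True})

# Builders that cannot produce the new format yet (e.g. Beam-based ones)
# must now opt out explicitly:
v_s3_off = tfds.core.Version("1.0.0",
                             experiments={tfds.core.Experiment.S3: False})
```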

docs/add_dataset.md

Lines changed: 10 additions & 11 deletions
@@ -103,7 +103,8 @@ Its subclasses implement:
     [`DatasetInfo`](api_docs/python/tfds/core/DatasetInfo.md) object
     describing the dataset
 * `_split_generators`: downloads the source data and defines the dataset splits
-* `_generate_examples`: yields examples in the dataset from the source data
+* `_generate_examples`: yields `(key, example)` tuples in the dataset from the
+  source data
 
 This guide will use `GeneratorBasedBuilder`.
 
@@ -131,7 +132,7 @@ class MyDataset(tfds.core.GeneratorBasedBuilder):
 
   def _generate_examples(self):
     # Yields examples from the dataset
-    pass  # TODO
+    yield 'key', {}
 ```
 
 If you'd like to follow a test-driven development workflow, which can help you
@@ -229,15 +230,13 @@ through [`tfds.Split.subsplit`](splits.md#subsplit).
   return [
       tfds.core.SplitGenerator(
           name=tfds.Split.TRAIN,
-          num_shards=10,
          gen_kwargs={
              "images_dir_path": os.path.join(extracted_path, "train"),
              "labels": os.path.join(extracted_path, "train_labels.csv"),
          },
      ),
      tfds.core.SplitGenerator(
          name=tfds.Split.TEST,
-          num_shards=1,
          gen_kwargs={
              "images_dir_path": os.path.join(extracted_path, "test"),
              "labels": os.path.join(extracted_path, "test_labels.csv"),
@@ -250,10 +249,6 @@ through [`tfds.Split.subsplit`](splits.md#subsplit).
 will be passed as keyword arguments to `_generate_examples`, which we'll define
 next.
 
-When specifying `num_shards`, which determines how many files the split will
-use, pick a number such that a single shard is less that 4 GiB as
-as each shard will be loaded in memory for shuffling.
-
 ## Writing an example generator
 
 `_generate_examples` generates the examples for each split from the
@@ -268,8 +263,8 @@ builder._generate_examples(
 ```
 
 This method will typically read source dataset artifacts (e.g. a CSV file) and
-yield feature dictionaries that correspond to the features specified in
-`DatasetInfo`.
+yield (key, feature dictionary) tuples that correspond to the features specified
+in `DatasetInfo`.
 
 ```python
 def _generate_examples(self, images_dir_path, labels):
@@ -281,7 +276,7 @@ def _generate_examples(self, images_dir_path, labels):
 
   # And yield examples as feature dictionaries
   for image_id, description, label in data:
-    yield {
+    yield image_id, {
        "image_description": description,
        "image": "%s/%s.jpeg" % (images_dir_path, image_id),
        "label": label,
@@ -293,6 +288,10 @@ format suitable for writing to disk (currently we use `tf.train.Example`
 protocol buffers). For example, `tfds.features.Image` will copy out the
 JPEG content of the passed image files automatically.
 
+The key (here: `image_id`) should uniquely identify the record. It is used to
+shuffle the dataset globally. If two records are yielded using the same key,
+an exception will be raised during preparation of the dataset.
+
 If you've implemented the test harness, your builder test should now pass.
 
 ### File access and `tf.io.gfile`
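Putting the updated contract together: under S3, `_generate_examples` yields `(key, example)` tuples, and the key drives global shuffling. A minimal sketch of such a generator follows; the CSV layout and field names are assumed for illustration, not taken from this commit:

```python
import csv
import os

import tensorflow as tf


def _generate_examples(self, images_dir_path, labels):
  """Yields (key, example) tuples; each key must be unique in the split."""
  with tf.io.gfile.GFile(labels) as f:
    for row in csv.DictReader(f):
      image_id = row["image_id"]
      # The key is hashed to shuffle the dataset globally; yielding the
      # same key twice raises an error while preparing the dataset.
      yield image_id, {
          "image": os.path.join(images_dir_path, "%s.jpeg" % image_id),
          "label": row["label"],
      }
```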

docs/beam_datasets.md

Lines changed: 3 additions & 1 deletion
@@ -67,7 +67,9 @@ a look at the
 ```python
 class DummyBeamDataset(tfds.core.BeamBasedBuilder):
 
-  VERSION = tfds.core.Version('1.0.0')
+  # BeamBasedBuilder does not support S3 yet.
+  VERSION = tfds.core.Version(
+      '1.0.0', experiments={tfds.core.Experiment.S3: False})
 
   def _info(self):
     return tfds.core.DatasetInfo(
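The new `test_s3_raise` in `dataset_builder_beam_test.py` (further down) pins down why this opt-out is needed: a Beam builder prepared with S3 enabled never fills in `DatasetInfo.SplitInfo.num_shards`, so `as_dataset()` fails.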

docs/datasets_versioning.md

Lines changed: 2 additions & 3 deletions
@@ -111,7 +111,6 @@ class MNIST(tfds.core.GeneratorBasedBuilder):
   VERSION = tfds.core.Version("1.0.0")
   SUPPORTED_VERSIONS = [
       tfds.core.Version("2.0.0", experiments={tfds.core.Experiment.S3: True}),
-      tfds.core.Version("1.0.0"),
   ]
   # Version history:
   # 2.0.0: S3 (new shuffling, sharding and slicing mechanism).
@@ -123,10 +122,10 @@ definition would then look like:
 
 ```py
 class MNIST(tfds.core.GeneratorBasedBuilder):
-  VERSION = tfds.core.Version("2.0.0")
+  VERSION = tfds.core.Version("1.0.0",
+                              experiments={tfds.core.Experiment.S3: False})
   SUPPORTED_VERSIONS = [
       tfds.core.Version("2.0.0"),
-      tfds.core.Version("1.0.0", experiments={tfds.core.Experiment.S3: False}),
   ]
   # Version history:
   # 2.0.0: S3 (new shuffling, sharding and slicing mechanism), order of records
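For users, swapping which version sits in `VERSION` versus `SUPPORTED_VERSIONS` keeps the default on the pre-S3 format while leaving the S3 variant reachable. Assuming the `name:version` selection syntax this doc describes, choosing between them would look like:

```python
import tensorflow_datasets as tfds

# Default: the canonical version declared in VERSION (here 1.0.0, S3 off).
ds = tfds.load("mnist")

# Explicitly request the S3 variant listed in SUPPORTED_VERSIONS.
ds_s3 = tfds.load("mnist:2.0.0")
```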

tensorflow_datasets/audio/groove.py

Lines changed: 5 additions & 6 deletions
@@ -77,7 +77,11 @@ def __init__(self, split_bars=None, include_audio=True, audio_rate=16000,
     else:
       name_parts.append("midionly")
 
-    super(GrooveConfig, self).__init__(name="-".join(name_parts), **kwargs)
+    super(GrooveConfig, self).__init__(
+        name="-".join(name_parts),
+        version=tfds.core.Version(
+            "1.0.0", experiments={tfds.core.Experiment.S3: False}),
+        **kwargs)
     self.split_bars = split_bars
     self.include_audio = include_audio
     self.audio_rate = audio_rate
@@ -89,30 +93,25 @@ class Groove(tfds.core.GeneratorBasedBuilder):
   BUILDER_CONFIGS = [
       GrooveConfig(
           include_audio=False,
-          version="1.0.0",
           description="Groove dataset without audio, unsplit."
       ),
       GrooveConfig(
           include_audio=True,
-          version="1.0.0",
           description="Groove dataset with audio, unsplit."
       ),
       GrooveConfig(
           include_audio=False,
           split_bars=2,
-          version="1.0.0",
           description="Groove dataset without audio, split into 2-bar chunks."
       ),
       GrooveConfig(
           include_audio=True,
           split_bars=2,
-          version="1.0.0",
           description="Groove dataset with audio, split into 2-bar chunks."
       ),
       GrooveConfig(
           include_audio=False,
           split_bars=4,
-          version="1.0.0",
           description="Groove dataset without audio, split into 4-bar chunks."
       ),
   ]
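The Groove change also illustrates the pattern this commit applies across builders: rather than repeating `version="1.0.0"` on each of the five `GrooveConfig` instances, the shared `Version` object (carrying the S3 opt-out) is set once in `GrooveConfig.__init__` and passed to the base `BuilderConfig`, so the opt-out lives in a single place.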

tensorflow_datasets/audio/librispeech.py

Lines changed: 4 additions & 2 deletions
@@ -149,12 +149,14 @@ def _make_builder_configs():
           encoder_cls=tfds.features.text.SubwordTextEncoder,
           vocab_size=2**15),
   ]
-  version = "0.1.0"
   configs = []
   for text_encoder_config in text_encoder_configs:
     for data in _DATA_OPTIONS:
       config = LibrispeechConfig(
-          version=version, text_encoder_config=text_encoder_config, data=data)
+          version=tfds.core.Version(
+              "0.0.1", experiments={tfds.core.Experiment.S3: False}),
+          text_encoder_config=text_encoder_config,
+          data=data)
       configs.append(config)
   return configs
 

tensorflow_datasets/audio/nsynth.py

Lines changed: 2 additions & 1 deletion
@@ -87,7 +87,8 @@
 class Nsynth(tfds.core.GeneratorBasedBuilder):
   """A large-scale and high-quality dataset of annotated musical notes."""
 
-  VERSION = tfds.core.Version("1.0.0")
+  VERSION = tfds.core.Version("1.0.0",
+                              experiments={tfds.core.Experiment.S3: False})
 
   def _info(self):
     return tfds.core.DatasetInfo(

tensorflow_datasets/core/dataset_builder_beam_test.py

Lines changed: 26 additions & 6 deletions
@@ -39,7 +39,7 @@
 
 class DummyBeamDataset(dataset_builder.BeamBasedBuilder):
 
-  VERSION = utils.Version("1.0.0")
+  VERSION = utils.Version("1.0.0", experiments={utils.Experiment.S3: False})
 
   def _info(self):
 
@@ -86,6 +86,11 @@ def _gen_example(x):
   }
 
 
+class FaultyS3DummyBeamDataset(DummyBeamDataset):
+
+  VERSION = utils.Version("1.0.0")
+
+
 class BeamBasedBuilderTest(testing.TestCase):
 
   def test_download_prepare_raise(self):
@@ -147,20 +152,35 @@ def _assertElemsAllEqual(self, nested_lhs, nested_rhs):
         self.assertAllEqual(lhs, rhs)
 
 
-  # The default beam pipeline do not works with Python2
-  def test_download_prepare(self):
-
+  def _get_dl_config_if_need_to_run(self):
+    # The default beam pipeline do not works with Python2
     # TODO(b/129148632): The current apache-beam 2.11.0 do not work with Py3
     # Update once the new version is out (around April)
     skip_beam_test = bool(six.PY3)
    if skip_beam_test:
      return
-
-    dl_config = download.DownloadConfig(
+    return download.DownloadConfig(
        beam_options=beam.options.pipeline_options.PipelineOptions(),
    )
+
+  def test_download_prepare(self):
+    dl_config = self._get_dl_config_if_need_to_run()
+    if not dl_config:
+      return
     self._assertBeamGeneration(dl_config)
 
+  def test_s3_raise(self):
+    dl_config = self._get_dl_config_if_need_to_run()
+    if not dl_config:
+      return
+    dl_config.compute_stats = download.ComputeStatsMode.SKIP
+    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
+      builder = FaultyS3DummyBeamDataset(data_dir=tmp_dir)
+      builder.download_and_prepare(download_config=dl_config)
+      with self.assertRaisesWithPredicateMatch(
+          AssertionError, "`DatasetInfo.SplitInfo.num_shards` is empty"):
+        builder.as_dataset()
+
 
 if __name__ == "__main__":
   testing.test_main()

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 23 additions & 23 deletions
@@ -52,13 +52,13 @@ class DummyDatasetWithConfigs(dataset_builder.GeneratorBasedBuilder):
   BUILDER_CONFIGS = [
       DummyBuilderConfig(
           name="plus1",
-          version="0.0.1",
+          version=utils.Version("0.0.1"),
           description="Add 1 to the records",
           increment=1),
       DummyBuilderConfig(
           name="plus2",
-          version="0.0.2",
-          supported_versions=["0.0.1"],
+          version=utils.Version("0.0.2"),
+          supported_versions=[utils.Version("0.0.1")],
           description="Add 2 to the records",
           increment=2),
   ]
@@ -70,12 +70,10 @@ def _split_generators(self, dl_manager):
     return [
         splits_lib.SplitGenerator(
             name=splits_lib.Split.TRAIN,
-            num_shards=2,
             gen_kwargs={"range_": range(20)},
         ),
         splits_lib.SplitGenerator(
             name=splits_lib.Split.TEST,
-            num_shards=1,
             gen_kwargs={"range_": range(20, 30)},
         ),
     ]
@@ -90,9 +88,10 @@ def _info(self):
 
   def _generate_examples(self, range_):
     for i in range_:
+      x = i
       if self.builder_config:
-        i += self.builder_config.increment
-      yield {"x": i}
+        x += self.builder_config.increment
+      yield i, {"x": x}
 
 
 class InvalidSplitDataset(DummyDatasetWithConfigs):
@@ -143,8 +142,8 @@ def test_determinism(self):
     # deterministically generated.
     self.assertEqual(
         [e["x"] for e in ds_values],
-        [16, 1, 2, 3, 10, 17, 0, 11, 14, 7, 4, 9, 18, 15, 8, 19, 6, 13, 12,
-         5],
+        [6, 16, 19, 12, 14, 18, 5, 13, 15, 4, 10, 17, 0, 8, 3, 1, 9, 7, 11,
+         2],
     )
 
   @testing.run_in_graph_and_eager_modes()
@@ -153,7 +152,7 @@ def test_multi_split(self):
       ds_train, ds_test = registered.load(
           name="dummy_dataset_shared_generator",
          data_dir=tmp_dir,
-          split=[splits_lib.Split.TRAIN, splits_lib.Split.TEST],
+          split=["train", "test"],
          as_dataset_kwargs=dict(shuffle_files=False))
 
      data = list(dataset_utils.as_numpy(ds_train))
@@ -220,12 +219,12 @@ def test_with_configs(self):
       # Test that subdirectories were created per config
       self.assertTrue(tf.io.gfile.exists(data_dir1))
       self.assertTrue(tf.io.gfile.exists(data_dir2))
-      # 2 train shards, 1 test shard, plus metadata files
-      self.assertGreater(len(tf.io.gfile.listdir(data_dir1)), 3)
-      self.assertGreater(len(tf.io.gfile.listdir(data_dir2)), 3)
+      # 1 train shard, 1 test shard, plus metadata files
+      self.assertGreater(len(tf.io.gfile.listdir(data_dir1)), 2)
+      self.assertGreater(len(tf.io.gfile.listdir(data_dir2)), 2)
 
       # Test that the config was used and they didn't collide.
-      splits_list = [splits_lib.Split.TRAIN, splits_lib.Split.TEST]
+      splits_list = ["train", "test"]
       for builder, incr in [(builder1, 1), (builder2, 2)]:
         train_data, test_data = [  # pylint: disable=g-complex-comprehension
             [el["x"] for el in  # pylint: disable=g-complex-comprehension
@@ -301,23 +300,24 @@ def load_mnist_dataset_info(self):
   def test_stats_restored_from_gcs(self):
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
       builder = testing.DummyMnist(data_dir=tmp_dir)
-      self.assertEqual(builder.info.splits.total_num_examples, 70000)
+      self.assertEqual(builder.info.splits.total_num_examples, 40)
       self.assertFalse(self.compute_dynamic_property.called)
 
       builder.download_and_prepare()
 
       # Statistics shouldn't have been recomputed
-      self.assertEqual(builder.info.splits.total_num_examples, 70000)
+      self.assertEqual(builder.info.splits.total_num_examples, 40)
       self.assertFalse(self.compute_dynamic_property.called)
 
   def test_stats_not_restored_gcs_overwritten(self):
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
       # If split are different that the one restored, stats should be recomputed
-      builder = testing.DummyMnist(data_dir=tmp_dir, num_shards=5)
-      self.assertEqual(builder.info.splits.total_num_examples, 70000)
+      builder = testing.DummyMnist(data_dir=tmp_dir)
+      self.assertEqual(builder.info.splits.total_num_examples, 40)
       self.assertFalse(self.compute_dynamic_property.called)
 
-      builder.download_and_prepare()
+      dl_config = download.DownloadConfig(max_examples_per_split=5)
+      builder.download_and_prepare(download_config=dl_config)
 
       # Statistics should have been recomputed (split different from the
       # restored ones)
@@ -347,7 +347,7 @@ def test_skip_stats(self):
     self.patch_gcs.stop()
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
       # No dataset_info restored, so stats are empty
-      builder = testing.DummyMnist(data_dir=tmp_dir, num_shards=5)
+      builder = testing.DummyMnist(data_dir=tmp_dir)
       self.assertEqual(builder.info.splits.total_num_examples, 0)
       self.assertFalse(self.compute_dynamic_property.called)
 
@@ -366,8 +366,8 @@ def test_force_stats(self):
 
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
       # No dataset_info restored, so stats are empty
-      builder = testing.DummyMnist(data_dir=tmp_dir, num_shards=5)
-      self.assertEqual(builder.info.splits.total_num_examples, 70000)
+      builder = testing.DummyMnist(data_dir=tmp_dir)
+      self.assertEqual(builder.info.splits.total_num_examples, 40)
       self.assertFalse(self.compute_dynamic_property.called)
 
       download_config = download.DownloadConfig(
@@ -433,7 +433,7 @@ def test_all_splits(self):
   @testing.run_in_graph_and_eager_modes()
   def test_with_batch_size(self):
     items = list(dataset_utils.as_numpy(self.builder.as_dataset(
-        split=splits_lib.Split.TRAIN + splits_lib.Split.TEST, batch_size=10)))
+        split="train+test", batch_size=10)))
     # 3 batches of 10
     self.assertEqual(3, len(items))
     x1, x2, x3 = items[0]["x"], items[1]["x"], items[2]["x"]
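These test updates also show the user-facing side of S3: splits are now addressed by plain strings, including the `+` concatenation syntax exercised in `test_with_batch_size`. A quick sketch of that API (the dataset name is illustrative):

```python
import tensorflow_datasets as tfds

# String split names replace the splits_lib.Split constants in user code...
ds_train = tfds.load("mnist", split="train")

# ...and "train+test" concatenates two splits into one dataset.
ds_all = tfds.load("mnist", split="train+test")
```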
