Skip to content

Commit 851eddb

Browse files
adarob authored and copybara-github committed
Add support for using beam in _split_generators()
PiperOrigin-RevId: 290874617
1 parent 2ec8361 commit 851eddb

File tree

2 files changed

+71
-15
lines changed

2 files changed

+71
-15
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import abc
2323
import functools
24+
import inspect
2425
import itertools
2526
import os
2627
import sys
@@ -853,13 +854,21 @@ def _prepare_split(self, split_generator, **kwargs):
853854
"""
854855
raise NotImplementedError()
855856

857+
def _make_split_generators_kwargs(self, prepare_split_kwargs):
858+
"""Get kwargs for `self._split_generators()` from `prepare_split_kwargs`."""
859+
del prepare_split_kwargs
860+
return {}
861+
856862
def _download_and_prepare(self, dl_manager, **prepare_split_kwargs):
857863
if not tf.io.gfile.exists(self._data_dir):
858864
tf.io.gfile.makedirs(self._data_dir)
859865

860866
# Generating data for all splits
861867
split_dict = splits_lib.SplitDict()
862-
for split_generator in self._split_generators(dl_manager):
868+
split_generators_kwargs = self._make_split_generators_kwargs(
869+
prepare_split_kwargs)
870+
for split_generator in self._split_generators(
871+
dl_manager, **split_generators_kwargs):
863872
if splits_lib.Split.ALL == split_generator.split_info.name:
864873
raise ValueError(
865874
"tfds.Split.ALL is a special split keyword corresponding to the "
@@ -1057,6 +1066,18 @@ def __init__(self, *args, **kwargs):
10571066
super(BeamBasedBuilder, self).__init__(*args, **kwargs)
10581067
self._beam_writers = {} # {split: beam_writer} mapping.
10591068

1069+
def _make_split_generators_kwargs(self, prepare_split_kwargs):
1070+
# Pass `pipeline` into `_split_generators()` from `prepare_split_kwargs` if
1071+
# it's in the call signature of `_split_generators()`.
1072+
# This allows for global preprocessing in beam.
1073+
split_generators_kwargs = {}
1074+
split_generators_arg_names = (
1075+
inspect.getargspec(self._split_generators).args if six.PY2 else
1076+
inspect.signature(self._split_generators).parameters.keys())
1077+
if "pipeline" in split_generators_arg_names:
1078+
split_generators_kwargs["pipeline"] = prepare_split_kwargs["pipeline"]
1079+
return split_generators_kwargs
1080+
10601081
@abc.abstractmethod
10611082
def _build_pcollection(self, pipeline, **kwargs):
10621083
"""Build the beam pipeline examples for each `SplitGenerator`.

tensorflow_datasets/core/dataset_builder_beam_test.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,7 @@ def _split_generators(self, dl_manager):
6666
),
6767
]
6868

69-
def _build_pcollection(self, pipeline, num_examples):
70-
"""Generate examples as dicts."""
71-
examples = (
72-
pipeline
73-
| beam.Create(range(num_examples))
74-
| beam.Map(_gen_example)
75-
)
76-
69+
def _compute_metadata(self, examples, num_examples):
7770
self.info.metadata["label_sum_%d" % num_examples] = (
7871
examples
7972
| beam.Map(lambda x: x[1]["label"])
@@ -83,6 +76,14 @@ def _build_pcollection(self, pipeline, num_examples):
8376
| beam.Map(lambda x: x[1]["id"])
8477
| beam.CombineGlobally(beam.combiners.MeanCombineFn()))
8578

79+
def _build_pcollection(self, pipeline, num_examples):
80+
"""Generate examples as dicts."""
81+
examples = (
82+
pipeline
83+
| beam.Create(range(num_examples))
84+
| beam.Map(_gen_example)
85+
)
86+
self._compute_metadata(examples, num_examples)
8687
return examples
8788

8889

@@ -94,6 +95,36 @@ def _gen_example(x):
9495
})
9596

9697

98+
class CommonPipelineDummyBeamDataset(DummyBeamDataset):
99+
100+
def _split_generators(self, dl_manager, pipeline):
101+
del dl_manager
102+
103+
examples = (
104+
pipeline
105+
| beam.Create(range(1000))
106+
| beam.Map(_gen_example)
107+
)
108+
109+
return [
110+
splits_lib.SplitGenerator(
111+
name=splits_lib.Split.TRAIN,
112+
gen_kwargs=dict(examples=examples, num_examples=1000),
113+
),
114+
splits_lib.SplitGenerator(
115+
name=splits_lib.Split.TEST,
116+
gen_kwargs=dict(examples=examples, num_examples=725),
117+
),
118+
]
119+
120+
def _build_pcollection(self, pipeline, examples, num_examples):
121+
"""Generate examples as dicts."""
122+
del pipeline
123+
examples |= beam.Filter(lambda x: x[0] < num_examples)
124+
self._compute_metadata(examples, num_examples)
125+
return examples
126+
127+
97128
class FaultyS3DummyBeamDataset(DummyBeamDataset):
98129

99130
VERSION = utils.Version("1.0.0")
@@ -107,24 +138,24 @@ def test_download_prepare_raise(self):
107138
with self.assertRaisesWithPredicateMatch(ValueError, "no Beam Runner"):
108139
builder.download_and_prepare()
109140

110-
def _assertBeamGeneration(self, dl_config):
141+
def _assertBeamGeneration(self, dl_config, dataset_cls, dataset_name):
111142
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
112-
builder = DummyBeamDataset(data_dir=tmp_dir)
143+
builder = dataset_cls(data_dir=tmp_dir)
113144
builder.download_and_prepare(download_config=dl_config)
114145

115-
data_dir = os.path.join(tmp_dir, "dummy_beam_dataset", "1.0.0")
146+
data_dir = os.path.join(tmp_dir, dataset_name, "1.0.0")
116147
self.assertEqual(data_dir, builder._data_dir)
117148

118149
# Check number of shards
119150
self._assertShards(
120151
data_dir,
121-
pattern="dummy_beam_dataset-test.tfrecord-{:05}-of-{:05}",
152+
pattern="%s-test.tfrecord-{:05}-of-{:05}" % dataset_name,
122153
# Liquid sharding is not guaranteed to always use the same number.
123154
num_shards=builder.info.splits["test"].num_shards,
124155
)
125156
self._assertShards(
126157
data_dir,
127-
pattern="dummy_beam_dataset-train.tfrecord-{:05}-of-{:05}",
158+
pattern="%s-train.tfrecord-{:05}-of-{:05}" % dataset_name,
128159
num_shards=1,
129160
)
130161

@@ -177,7 +208,11 @@ def test_download_prepare(self):
177208
dl_config = self._get_dl_config_if_need_to_run()
178209
if not dl_config:
179210
return
180-
self._assertBeamGeneration(dl_config)
211+
self._assertBeamGeneration(
212+
dl_config, DummyBeamDataset, "dummy_beam_dataset")
213+
self._assertBeamGeneration(
214+
dl_config, CommonPipelineDummyBeamDataset,
215+
"common_pipeline_dummy_beam_dataset")
181216

182217

183218
if __name__ == "__main__":

0 commit comments

Comments (0)