Skip to content

Commit 2a7f229

Browse files
pierrot0copybara-github
authored andcommitted
no-op refactorting: file_format_adapter write_from_generator accepts a generator, instead of generator_fn.
PiperOrigin-RevId: 251828067
1 parent 5fd860b commit 2a7f229

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -823,8 +823,7 @@ def _prepare_split(self, split_generator, max_examples_per_split):
823823
generator = itertools.islice(generator, max_examples_per_split)
824824
generator = (self.info.features.encode_example(ex) for ex in generator)
825825
output_files = self._build_split_filenames(split_generator.split_info)
826-
self._file_format_adapter.write_from_generator(
827-
lambda: generator, output_files)
826+
self._file_format_adapter.write_from_generator(generator, output_files)
828827

829828

830829
class BeamBasedBuilder(FileAdapterBuilder):

tensorflow_datasets/core/file_format_adapter.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,11 @@ def __init__(self, example_specs):
6969
del example_specs
7070

7171
@abc.abstractmethod
72-
def write_from_generator(self, generator_fn, output_files):
72+
def write_from_generator(self, generator, output_files):
7373
"""Write to files from generators_and_filenames.
7474
7575
Args:
76-
generator_fn: returns generator yielding dictionaries of feature name to
77-
value.
76+
generator: generator yielding dictionaries of feature name to value.
7877
output_files: `list<str>`, output files to write files to.
7978
"""
8079
raise NotImplementedError
@@ -121,10 +120,9 @@ def __init__(self, example_specs):
121120
example_specs)
122121
self._parser = example_parser.ExampleParser(example_specs)
123122

124-
def write_from_generator(self, generator_fn, output_files):
125-
wrapped = (
126-
self._serializer.serialize_example(example)
127-
for example in generator_fn())
123+
def write_from_generator(self, generator, output_files):
124+
wrapped = (self._serializer.serialize_example(example)
125+
for example in generator)
128126
_write_tfrecords_from_generator(wrapped, output_files, shuffle=True)
129127

130128
def write_from_pcollection(self, pcollection, file_path_prefix, num_shards):

tensorflow_datasets/testing/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def features_encode_decode(features_dict, example, as_tensor=False):
296296
file_adapter = file_format_adapter.TFRecordExampleAdapter(
297297
features_dict.get_serialized_info())
298298
file_adapter.write_from_generator(
299-
generator_fn=lambda: [encoded_example],
299+
generator=[encoded_example],
300300
output_files=[tmp_filename],
301301
)
302302
ds = file_adapter.dataset_from_filename(tmp_filename)

0 commit comments

Comments
 (0)