
Commit 34ae6dd

adarob authored and copybara-github committed

Expose API for reading individual TFRecord files.
PiperOrigin-RevId: 292534021
1 parent cf61126 commit 34ae6dd

2 files changed: +60 -26 lines changed

tensorflow_datasets/core/tfrecords_reader.py

Lines changed: 50 additions & 26 deletions

@@ -19,6 +19,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import copy
 import functools
 import math
 import os
@@ -160,34 +161,22 @@ def _make_file_instructions_from_absolutes(
   )
 
 
-def _read_single_instruction(
-    instruction,
+def _read_files(
+    files,
     parse_fn,
     read_config,
-    name,
-    path,
-    split_infos,
     shuffle_files):
-  """Returns tf.data.Dataset for given instruction.
+  """Returns tf.data.Dataset for given file instructions.
 
   Args:
-    instruction (ReadInstruction or str): if str, a ReadInstruction will be
-      constructed using `ReadInstruction.from_spec(str)`.
+    files: List[dict(filename, skip, take)], the files information.
+      The filenames contain the absolute path, not relative.
+      skip/take indicate which examples to read in the shard: `ds.skip().take()`
     parse_fn (callable): function used to parse each record.
     read_config: `tfds.ReadConfig`, Additional options to configure the
      input pipeline (e.g. seed, num parallel reads,...).
-    name (str): name of the dataset.
-    path (str): path to directory where to read tfrecords from.
-    split_infos: `SplitDict`, the `info.splits` container of `SplitInfo`.
    shuffle_files (bool): Defaults to False. True to shuffle input files.
   """
-  file_instructions = make_file_instructions(name, split_infos, instruction)
-  for fi in file_instructions.file_instructions:
-    fi['filename'] = os.path.join(path, fi['filename'])
-  files = file_instructions.file_instructions
-  if not files:
-    msg = 'Instruction "%s" corresponds to no data!' % instruction
-    raise AssertionError(msg)
   # Eventually apply a transformation to the instruction function.
   # This allows the user to have direct control over the interleave order.
   if read_config.experimental_interleave_sort_fn is not None:
@@ -276,16 +265,51 @@ def read(
       ReadInstruction instance. Otherwise a dict/list of tf.data.Dataset
       corresponding to given instructions param shape.
     """
-    read_instruction = functools.partial(
-        _read_single_instruction,
-        parse_fn=self._parser.parse_example,
+    def _read_instruction_to_file_instructions(instruction):
+      file_instructions = make_file_instructions(name, split_infos, instruction)
+      files = file_instructions.file_instructions
+      if not files:
+        msg = 'Instruction "%s" corresponds to no data!' % instruction
+        raise AssertionError(msg)
+      return tuple(files)
+
+    files = utils.map_nested(
+        _read_instruction_to_file_instructions, instructions, map_tuple=False)
+    return utils.map_nested(
+        functools.partial(
+            self.read_files, read_config=read_config,
+            shuffle_files=shuffle_files),
+        files,
+        map_tuple=False)
+
+  def read_files(
+      self,
+      files,
+      read_config,
+      shuffle_files
+  ):
+    """Returns single tf.data.Dataset instance for the set of file instructions.
+
+    Args:
+      files: List[dict(filename, skip, take)], the files information.
+        The filenames contain the relative path, not absolute.
+        skip/take indicate which examples to read in the shard: `ds.skip().take()`
+      read_config: `tfds.ReadConfig`, the input pipeline options
+      shuffle_files (bool): If True, input files are shuffled before being read.
+
+    Returns:
+      a tf.data.Dataset instance.
+    """
+    # Prepend path to filename
+    files = copy.deepcopy(files)
+    for f in files:
+      f.update(filename=os.path.join(self._path, f['filename']))
+    dataset = _read_files(
+        files=files,
         read_config=read_config,
-        split_infos=split_infos,
-        name=name,
-        path=self._path,
+        parse_fn=self._parser.parse_example,
        shuffle_files=shuffle_files)
-    datasets = utils.map_nested(read_instruction, instructions, map_tuple=True)
-    return datasets
+    return dataset
 
 
 @attr.s(frozen=True)
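With this refactor, `read()` maps each `ReadInstruction` to a tuple of per-file `dict(filename, skip, take)` instructions and delegates the actual reading to the newly exposed `read_files()`. A minimal usage sketch of the new method follows; it assumes `reader` is an already-constructed `tfrecords_reader.Reader` pointing at a prepared dataset directory, and the shard filenames are illustrative only, not part of this commit:

import tensorflow_datasets as tfds

# Hypothetical file instructions: filenames are relative to the reader's path;
# skip/take select examples within each shard (take=-1 reads to the shard's end).
files = [
    {'filename': 'mnist-train.tfrecord-00001-of-00004', 'skip': 0, 'take': -1},
    {'filename': 'mnist-train.tfrecord-00003-of-00004', 'skip': 1, 'take': 1},
]

# `reader` is assumed to be an existing tfrecords_reader.Reader instance
# (e.g. the one a DatasetBuilder creates); its construction is not shown in this diff.
ds = reader.read_files(
    files,
    read_config=tfds.ReadConfig(),
    shuffle_files=False)
examples = list(tfds.as_numpy(ds))  # materialized, parsed examples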

tensorflow_datasets/core/tfrecords_reader_test.py

Lines changed: 10 additions & 0 deletions

@@ -383,6 +383,16 @@ def test_4fold(self):
         [b'a', b'b', b'c', b'd', b'e', b'f', b'j', b'k', b'l'],
         [b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i']])
 
+  def test_read_files(self):
+    self._write_tfrecord('train', 4, 'abcdefghijkl')
+    fname_pattern = 'mnist-train.tfrecord-0000%d-of-00004'
+    ds = self.reader.read_files(
+        [{'filename': fname_pattern % 1, 'skip': 0, 'take': -1},
+         {'filename': fname_pattern % 3, 'skip': 1, 'take': 1}],
+        read_config=read_config_lib.ReadConfig(),
+        shuffle_files=False)
+    read_data = list(tfds.as_numpy(ds))
+    self.assertEqual(read_data, [six.b(l) for l in 'defk'])
 
 if __name__ == '__main__':
   testing.test_main()
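The skip/take fields in each file-instruction dict follow plain tf.data semantics, as the docstring above notes (`ds.skip().take()`). A rough standalone sketch of that per-shard behavior, with an illustrative helper name and the shard contents taken from the test above:

import tensorflow as tf

def _load_shard(instruction):
  # Roughly what one file instruction encodes: skip `skip` records, then read
  # `take` records (take=-1 means "until the end of the shard").
  ds = tf.data.TFRecordDataset(instruction['filename'])
  return ds.skip(instruction['skip']).take(instruction['take'])

# With shard 00001 holding b'd', b'e', b'f' and shard 00003 holding b'j', b'k',
# b'l' (as written by the test), the two instructions select b'd', b'e', b'f'
# and b'k', hence the expected result 'defk'.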
