|
19 | 19 | from __future__ import division
|
20 | 20 | from __future__ import print_function
|
21 | 21 |
|
| 22 | +import copy |
22 | 23 | import functools
|
23 | 24 | import math
|
24 | 25 | import os
|
@@ -160,34 +161,22 @@ def _make_file_instructions_from_absolutes(
|
160 | 161 | )
|
161 | 162 |
|
162 | 163 |
|
163 |
| -def _read_single_instruction( |
164 |
| - instruction, |
| 164 | +def _read_files( |
| 165 | + files, |
165 | 166 | parse_fn,
|
166 | 167 | read_config,
|
167 |
| - name, |
168 |
| - path, |
169 |
| - split_infos, |
170 | 168 | shuffle_files):
|
171 |
| - """Returns tf.data.Dataset for given instruction. |
| 169 | + """Returns tf.data.Dataset for given file instructions. |
172 | 170 |
|
173 | 171 | Args:
|
174 |
| - instruction (ReadInstruction or str): if str, a ReadInstruction will be |
175 |
| - constructed using `ReadInstruction.from_spec(str)`. |
| 172 | + files: List[dict(filename, skip, take)], the files information. |
| 173 | + The filenames contain the absolute path, not relative. |
| 174 | + skip/take indicate which examples to read in the shard: `ds.skip().take()` |
176 | 175 | parse_fn (callable): function used to parse each record.
|
177 | 176 | read_config: `tfds.ReadConfig`, Additional options to configure the
|
178 | 177 | input pipeline (e.g. seed, num parallel reads,...).
|
179 |
| - name (str): name of the dataset. |
180 |
| - path (str): path to directory where to read tfrecords from. |
181 |
| - split_infos: `SplitDict`, the `info.splits` container of `SplitInfo`. |
182 | 178 | shuffle_files (bool): Defaults to False. True to shuffle input files.
|
183 | 179 | """
|
184 |
| - file_instructions = make_file_instructions(name, split_infos, instruction) |
185 |
| - for fi in file_instructions.file_instructions: |
186 |
| - fi['filename'] = os.path.join(path, fi['filename']) |
187 |
| - files = file_instructions.file_instructions |
188 |
| - if not files: |
189 |
| - msg = 'Instruction "%s" corresponds to no data!' % instruction |
190 |
| - raise AssertionError(msg) |
191 | 180 | # Eventually apply a transformation to the instruction function.
|
192 | 181 | # This allow the user to have direct control over the interleave order.
|
193 | 182 | if read_config.experimental_interleave_sort_fn is not None:
|
@@ -276,16 +265,51 @@ def read(
|
276 | 265 | ReadInstruction instance. Otherwise a dict/list of tf.data.Dataset
|
277 | 266 | corresponding to given instructions param shape.
|
278 | 267 | """
|
279 |
| - read_instruction = functools.partial( |
280 |
| - _read_single_instruction, |
281 |
| - parse_fn=self._parser.parse_example, |
| 268 | + def _read_instruction_to_file_instructions(instruction): |
| 269 | + file_instructions = make_file_instructions(name, split_infos, instruction) |
| 270 | + files = file_instructions.file_instructions |
| 271 | + if not files: |
| 272 | + msg = 'Instruction "%s" corresponds to no data!' % instruction |
| 273 | + raise AssertionError(msg) |
| 274 | + return tuple(files) |
| 275 | + |
| 276 | + files = utils.map_nested( |
| 277 | + _read_instruction_to_file_instructions, instructions, map_tuple=False) |
| 278 | + return utils.map_nested( |
| 279 | + functools.partial( |
| 280 | + self.read_files, read_config=read_config, |
| 281 | + shuffle_files=shuffle_files), |
| 282 | + files, |
| 283 | + map_tuple=False) |
| 284 | + |
| 285 | + def read_files( |
| 286 | + self, |
| 287 | + files, |
| 288 | + read_config, |
| 289 | + shuffle_files |
| 290 | + ): |
| 291 | + """Returns single tf.data.Dataset instance for the set of file instructions. |
| 292 | +
|
| 293 | + Args: |
| 294 | + files: List[dict(filename, skip, take)], the files information. |
| 295 | + The filenames contain the relative path, not absolute. |
| 296 | + skip/take indicate which examples to read in the shard: `ds.skip().take()` |
| 297 | + read_config: `tfds.ReadConfig`, the input pipeline options |
| 298 | + shuffle_files (bool): If True, input files are shuffled before being read. |
| 299 | +
|
| 300 | + Returns: |
| 301 | + a tf.data.Dataset instance. |
| 302 | + """ |
| 303 | + # Prepend path to filename |
| 304 | + files = copy.deepcopy(files) |
| 305 | + for f in files: |
| 306 | + f.update(filename=os.path.join(self._path, f['filename'])) |
| 307 | + dataset = _read_files( |
| 308 | + files=files, |
282 | 309 | read_config=read_config,
|
283 |
| - split_infos=split_infos, |
284 |
| - name=name, |
285 |
| - path=self._path, |
| 310 | + parse_fn=self._parser.parse_example, |
286 | 311 | shuffle_files=shuffle_files)
|
287 |
| - datasets = utils.map_nested(read_instruction, instructions, map_tuple=True) |
288 |
| - return datasets |
| 312 | + return dataset |
289 | 313 |
|
290 | 314 |
|
291 | 315 | @attr.s(frozen=True)
|
|
0 commit comments