@@ -165,7 +165,8 @@ def _read_files(
165
165
files ,
166
166
parse_fn ,
167
167
read_config ,
168
- shuffle_files ):
168
+ shuffle_files ,
169
+ num_examples ):
169
170
"""Returns tf.data.Dataset for given file instructions.
170
171
171
172
Args:
@@ -176,6 +177,8 @@ def _read_files(
176
177
read_config: `tfds.ReadConfig`, Additional options to configure the
177
178
input pipeline (e.g. seed, num parallel reads,...).
178
179
shuffle_files (bool): Defaults to False. True to shuffle input files.
180
+ num_examples: `int`, if defined, set the cardinality on the
181
+ tf.data.Dataset instance with `tf.data.experimental.assert_cardinality`.
179
182
"""
180
183
# Eventually apply a transformation to the instruction function.
181
184
# This allows the user to have direct control over the interleave order.
@@ -211,7 +214,13 @@ def _read_files(
211
214
cycle_length = parallel_reads ,
212
215
block_length = block_length ,
213
216
num_parallel_calls = tf .data .experimental .AUTOTUNE ,
214
- )
217
+ )
218
+
219
+ # If the number of examples read in the tf-record is known, we forward
220
+ # the information to the tf.data.Dataset object.
221
+ # Check that `tf.data.experimental.assert_cardinality` exists, for backward compatibility with TF <= 2.1
222
+ if num_examples and hasattr (tf .data .experimental , 'assert_cardinality' ):
223
+ ds = ds .apply (tf .data .experimental .assert_cardinality (num_examples ))
215
224
216
225
# TODO(tfds): Should merge the default options with read_config to allow users
217
226
# to overwrite the default options.
@@ -265,28 +274,27 @@ def read(
265
274
ReadInstruction instance. Otherwise a dict/list of tf.data.Dataset
266
275
corresponding to given instructions param shape.
267
276
"""
268
- def _read_instruction_to_file_instructions (instruction ):
277
+ def _read_instruction_to_ds (instruction ):
269
278
file_instructions = make_file_instructions (name , split_infos , instruction )
270
279
files = file_instructions .file_instructions
271
280
if not files :
272
281
msg = 'Instruction "%s" corresponds to no data!' % instruction
273
282
raise AssertionError (msg )
274
- return tuple (files )
283
+ return self .read_files (
284
+ files = tuple (files ),
285
+ read_config = read_config ,
286
+ shuffle_files = shuffle_files ,
287
+ num_examples = file_instructions .num_examples ,
288
+ )
275
289
276
- files = utils .map_nested (
277
- _read_instruction_to_file_instructions , instructions , map_tuple = False )
278
- return utils .map_nested (
279
- functools .partial (
280
- self .read_files , read_config = read_config ,
281
- shuffle_files = shuffle_files ),
282
- files ,
283
- map_tuple = False )
290
+ return tf .nest .map_structure (_read_instruction_to_ds , instructions )
284
291
285
292
def read_files (
286
293
self ,
287
294
files ,
288
295
read_config ,
289
- shuffle_files
296
+ shuffle_files ,
297
+ num_examples = None ,
290
298
):
291
299
"""Returns single tf.data.Dataset instance for the set of file instructions.
292
300
@@ -296,6 +304,8 @@ def read_files(
296
304
skip/take indicates which example read in the shard: `ds.skip().take()`
297
305
read_config: `tfds.ReadConfig`, the input pipeline options
298
306
shuffle_files (bool): If True, input files are shuffled before being read.
307
+ num_examples: `int`, if defined, set the cardinality on the
308
+ tf.data.Dataset instance with `tf.data.experimental.assert_cardinality`.
299
309
300
310
Returns:
301
311
a tf.data.Dataset instance.
@@ -308,7 +318,9 @@ def read_files(
308
318
files = files ,
309
319
read_config = read_config ,
310
320
parse_fn = self ._parser .parse_example ,
311
- shuffle_files = shuffle_files )
321
+ shuffle_files = shuffle_files ,
322
+ num_examples = num_examples ,
323
+ )
312
324
return dataset
313
325
314
326
0 commit comments