
Commit 30c11af

Conchylicultor authored and copybara-github committed
Expose tf.data.Dataset cardinality in TFDS
PiperOrigin-RevId: 296016166
1 parent 97cdcb7 commit 30c11af
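
For context, the change makes the number of examples in a split visible through the standard tf.data cardinality API. A minimal sketch of the resulting behavior (not part of this commit; it assumes TF >= 2.2, where `tf.data.experimental.assert_cardinality` exists, and an already prepared dataset such as 'mnist'):

import tensorflow as tf
import tensorflow_datasets as tfds

# The TFDS reader now asserts the known split size on the returned dataset.
ds = tfds.load('mnist', split='train')

# Cardinality is reported from the split metadata instead of
# UNKNOWN_CARDINALITY, so no pass over the data is needed.
print(tf.data.experimental.cardinality(ds).numpy())  # e.g. 60000 for MNIST train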

File tree

3 files changed: +48 -15 lines


tensorflow_datasets/core/splits.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def __init__(self, file_instructions):
     """Constructor.
 
     Args:
-      file_instructions: _FileInstructionOutput
+      file_instructions: FileInstructions
     """
     self._file_instructions = file_instructions
 
tensorflow_datasets/core/tfrecords_reader.py

Lines changed: 26 additions & 14 deletions
@@ -165,7 +165,8 @@ def _read_files(
     files,
     parse_fn,
     read_config,
-    shuffle_files):
+    shuffle_files,
+    num_examples):
   """Returns tf.data.Dataset for given file instructions.
 
   Args:
@@ -176,6 +177,8 @@ def _read_files(
     read_config: `tfds.ReadConfig`, Additional options to configure the
       input pipeline (e.g. seed, num parallel reads,...).
     shuffle_files (bool): Defaults to False. True to shuffle input files.
+    num_examples: `int`, if defined, set the cardinality on the
+      tf.data.Dataset instance with `tf.data.experimental.with_cardinality`.
   """
   # Eventually apply a transformation to the instruction function.
   # This allow the user to have direct control over the interleave order.
@@ -211,7 +214,13 @@ def _read_files(
       cycle_length=parallel_reads,
       block_length=block_length,
       num_parallel_calls=tf.data.experimental.AUTOTUNE,
-      )
+  )
+
+  # If the number of examples read in the tf-record is known, we forward
+  # the information to the tf.data.Dataset object.
+  # Check the `tf.data.experimental` for backward compatibility with TF <= 2.1
+  if num_examples and hasattr(tf.data.experimental, 'assert_cardinality'):
+    ds = ds.apply(tf.data.experimental.assert_cardinality(num_examples))
 
   # TODO(tfds): Should merge the default options with read_config to allow users
   # to overwrite the default options.
@@ -265,28 +274,27 @@ def read(
         ReadInstruction instance. Otherwise a dict/list of tf.data.Dataset
         corresponding to given instructions param shape.
     """
-    def _read_instruction_to_file_instructions(instruction):
+    def _read_instruction_to_ds(instruction):
       file_instructions = make_file_instructions(name, split_infos, instruction)
       files = file_instructions.file_instructions
       if not files:
         msg = 'Instruction "%s" corresponds to no data!' % instruction
         raise AssertionError(msg)
-      return tuple(files)
+      return self.read_files(
+          files=tuple(files),
+          read_config=read_config,
+          shuffle_files=shuffle_files,
+          num_examples=file_instructions.num_examples,
+      )
 
-    files = utils.map_nested(
-        _read_instruction_to_file_instructions, instructions, map_tuple=False)
-    return utils.map_nested(
-        functools.partial(
-            self.read_files, read_config=read_config,
-            shuffle_files=shuffle_files),
-        files,
-        map_tuple=False)
+    return tf.nest.map_structure(_read_instruction_to_ds, instructions)
 
   def read_files(
       self,
       files,
       read_config,
-      shuffle_files
+      shuffle_files,
+      num_examples=None,
   ):
     """Returns single tf.data.Dataset instance for the set of file instructions.
 
@@ -296,6 +304,8 @@ def read_files(
        skip/take indicates which example read in the shard: `ds.skip().take()`
       read_config: `tfds.ReadConfig`, the input pipeline options
       shuffle_files (bool): If True, input files are shuffled before being read.
+      num_examples: `int`, if defined, set the cardinality on the
+        tf.data.Dataset instance with `tf.data.experimental.with_cardinality`.
 
     Returns:
       a tf.data.Dataset instance.
@@ -308,7 +318,9 @@ def read_files(
         files=files,
        read_config=read_config,
        parse_fn=self._parser.parse_example,
-        shuffle_files=shuffle_files)
+        shuffle_files=shuffle_files,
+        num_examples=num_examples,
+    )
     return dataset
 
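Regarding the guarded call to `assert_cardinality` above: the op exists only from TF 2.2 onward (hence the `hasattr` check), and it does more than attach metadata; the asserted value is also verified while the dataset is iterated. A standalone sketch of the API, not taken from this commit:

import tensorflow as tf

# After a filter(), tf.data can no longer infer the size statically.
ds = tf.data.Dataset.range(10).filter(lambda x: x >= 0)
print(tf.data.experimental.cardinality(ds))  # -2, i.e. UNKNOWN_CARDINALITY

# Asserting the known size lets tf.data report it without reading the data.
ds = ds.apply(tf.data.experimental.assert_cardinality(10))
print(tf.data.experimental.cardinality(ds))  # 10

# If the asserted value were wrong, iterating the dataset would raise
# tf.errors.FailedPreconditionError once the mismatch is detected.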
tensorflow_datasets/core/tfrecords_reader_test.py

Lines changed: 21 additions & 0 deletions
@@ -26,6 +26,8 @@
 from absl.testing import absltest
 import six
 
+import tensorflow as tf
+
 import tensorflow_datasets as tfds
 from tensorflow_datasets import testing
 from tensorflow_datasets.core import example_parser
@@ -35,6 +37,10 @@
 from tensorflow_datasets.core.utils import read_config as read_config_lib
 
 
+# Skip the cardinality test for backward compatibility with TF <= 2.1.
+_SKIP_CARDINALITY_TEST = not hasattr(tf.data.experimental, 'assert_cardinality')
+
+
 class GetDatasetFilesTest(testing.TestCase):
 
   NAME2SHARD_LENGTHS = {
@@ -308,19 +314,34 @@ def test_noskip_notake(self):
     read_data = list(tfds.as_numpy(ds))
     self.assertEqual(read_data, [six.b(l) for l in 'abcdefghijkl'])
 
+    if not _SKIP_CARDINALITY_TEST:
+      # Check that the cardinality is correctly set.
+      self.assertEqual(
+          tf.data.experimental.cardinality(ds).numpy(), len(read_data))
+
   def test_overlap(self):
     self._write_tfrecord('train', 5, 'abcdefghijkl')
     ds = self.reader.read('mnist', 'train+train[:2]', self.SPLIT_INFOS)
     read_data = list(tfds.as_numpy(ds))
     self.assertEqual(read_data, [six.b(l) for l in 'abcdefghijklab'])
 
+    if not _SKIP_CARDINALITY_TEST:
+      # Check that the cardinality is correctly set.
+      self.assertEqual(
+          tf.data.experimental.cardinality(ds).numpy(), len(read_data))
+
   def test_complex(self):
     self._write_tfrecord('train', 5, 'abcdefghijkl')
     self._write_tfrecord('test', 3, 'mnopqrs')
     ds = self.reader.read('mnist', 'train[1:-1]+test[:-50%]', self.SPLIT_INFOS)
     read_data = list(tfds.as_numpy(ds))
     self.assertEqual(read_data, [six.b(l) for l in 'bcdefghijkmno'])
 
+    if not _SKIP_CARDINALITY_TEST:
+      # Check that the cardinality is correctly set.
+      self.assertEqual(
+          tf.data.experimental.cardinality(ds).numpy(), len(read_data))
+
   def test_shuffle_files(self):
     self._write_tfrecord('train', 5, 'abcdefghijkl')
     ds = self.reader.read('mnist', 'train', self.SPLIT_INFOS,
