Skip to content

Commit cf61126

Browse files
Conchylicultor authored and copybara-github committed
Expose the subsplits num_examples and instructions to the public API
PiperOrigin-RevId: 292380707
1 parent 5d034f0 commit cf61126

File tree

9 files changed

+233
-76
lines changed

9 files changed

+233
-76
lines changed

tensorflow_datasets/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from tensorflow_datasets.core.splits import SplitDict
4040
from tensorflow_datasets.core.splits import SplitGenerator
4141
from tensorflow_datasets.core.splits import SplitInfo
42+
from tensorflow_datasets.core.splits import SubSplitInfo
4243

4344
from tensorflow_datasets.core.tfrecords_reader import ReadInstruction
4445

tensorflow_datasets/core/dataset_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ def _download_and_prepare(self, dl_manager, **prepare_split_kwargs):
870870
tf.io.gfile.makedirs(self._data_dir)
871871

872872
# Generating data for all splits
873-
split_dict = splits_lib.SplitDict()
873+
split_dict = splits_lib.SplitDict(dataset_name=self.name)
874874
split_generators_kwargs = self._make_split_generators_kwargs(
875875
prepare_split_kwargs)
876876
for split_generator in self._split_generators(

tensorflow_datasets/core/dataset_info.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def __init__(self,
151151
"the top-level. Got {}".format(features))
152152
features._set_top_level() # pylint: disable=protected-access
153153
self._features = features
154-
self._splits = splits_lib.SplitDict()
154+
self._splits = splits_lib.SplitDict(self._builder.name)
155155
if supervised_keys is not None:
156156
assert isinstance(supervised_keys, tuple)
157157
assert len(supervised_keys) == 2
@@ -203,6 +203,10 @@ def homepage(self):
203203
def citation(self):
204204
return self.as_proto.citation
205205

206+
@property
207+
def data_dir(self):
208+
return self._builder.data_dir
209+
206210
@property
207211
def size_in_bytes(self):
208212
size_in_bytes = sum(split.num_bytes for split in self.splits.values())
@@ -362,7 +366,8 @@ def read_from_directory(self, dataset_info_dir):
362366
parsed_proto = read_from_json(json_filename)
363367

364368
# Update splits
365-
self._set_splits(splits_lib.SplitDict.from_proto(parsed_proto.splits))
369+
split_dict = splits_lib.SplitDict.from_proto(self.name, parsed_proto.splits)
370+
self._set_splits(split_dict)
366371

367372
# Restore the feature metadata (vocabulary, labels names,...)
368373
if self.features:

tensorflow_datasets/core/naming.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,21 @@ def filepattern_for_dataset_split(dataset_name, split, data_dir,
6767
return "%s*" % filepath
6868

6969

70-
def filepaths_for_dataset_split(dataset_name, split, num_shards, data_dir,
71-
filetype_suffix=None):
70+
def filenames_for_dataset_split(
71+
dataset_name, split, num_shards, filetype_suffix=None):
7272
prefix = filename_prefix_for_split(dataset_name, split)
7373
if filetype_suffix:
7474
prefix += ".%s" % filetype_suffix
75-
filenames = sharded_filenames(prefix, num_shards)
75+
return sharded_filenames(prefix, num_shards)
76+
77+
78+
def filepaths_for_dataset_split(dataset_name, split, num_shards, data_dir,
79+
filetype_suffix=None):
80+
filenames = filenames_for_dataset_split(
81+
dataset_name=dataset_name,
82+
split=split,
83+
num_shards=num_shards,
84+
filetype_suffix=filetype_suffix,
85+
)
7686
filepaths = [os.path.join(data_dir, fname) for fname in filenames]
7787
return filepaths

tensorflow_datasets/core/naming_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ def test_filename_prefix_for_split(self, prefix, expected):
6464
split = splits.Split.TRAIN
6565
self.assertEqual(expected, naming.filename_prefix_for_split(prefix, split))
6666

67+
def test_filenames_for_dataset_split(self):
68+
self.assertEqual([
69+
"foo-train-00000-of-00002",
70+
"foo-train-00001-of-00002",
71+
], naming.filenames_for_dataset_split(
72+
dataset_name="foo",
73+
split=splits.Split.TRAIN,
74+
num_shards=2))
75+
6776
def test_filepaths_for_dataset_split(self):
6877
self.assertEqual([
6978
"/tmp/bar/foo-train-00000-of-00002",

tensorflow_datasets/core/splits.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from six.moves import range # pylint: disable=redefined-builtin
2828

2929
from tensorflow_datasets.core import proto
30+
from tensorflow_datasets.core import tfrecords_reader
3031
from tensorflow_datasets.core import utils
3132

3233

@@ -45,6 +46,37 @@ def __repr__(self):
4546
return "<tfds.core.SplitInfo num_examples=%s>" % str(num_examples)
4647

4748

49+
class SubSplitInfo(object):
50+
"""Wrapper around a sub split info.
51+
52+
This class expose info on the subsplit:
53+
54+
```
55+
ds, info = tfds.load(..., split='train[75%:]', with_info=True)
56+
info.splits['train[75%:]'].num_examples
57+
```
58+
59+
"""
60+
61+
def __init__(self, file_instructions):
62+
"""Constructor.
63+
64+
Args:
65+
file_instructions: _FileInstructionOutput
66+
"""
67+
self._file_instructions = file_instructions
68+
69+
@property
70+
def num_examples(self):
71+
"""Returns the number of example in the subsplit."""
72+
return self._file_instructions.num_examples
73+
74+
@property
75+
def file_instructions(self):
76+
"""Returns the list of dict(filename, take, skip)."""
77+
return self._file_instructions.file_instructions
78+
79+
4880
@six.add_metaclass(abc.ABCMeta)
4981
class SplitBase(object):
5082
# pylint: disable=line-too-long
@@ -527,14 +559,22 @@ def compute_mask_offsets(shard_id2num_examples):
527559
class SplitDict(utils.NonMutableDict):
528560
"""Split info object."""
529561

530-
def __init__(self):
562+
def __init__(self, dataset_name):
531563
super(SplitDict, self).__init__(error_msg="Split {key} already present")
564+
self._dataset_name = dataset_name
532565

533566
def __getitem__(self, key):
534-
if str(key) not in self:
535-
raise KeyError("Invalid split %s. Available splits are: %s" % (
536-
key, sorted(list(self.keys()))))
537-
return super(SplitDict, self).__getitem__(str(key))
567+
# 1st case: The key exists: `info.splits['train']`
568+
if str(key) in self:
569+
return super(SplitDict, self).__getitem__(str(key))
570+
# 2nd case: Uses instructions: `info.splits['train[50%]']`
571+
else:
572+
instructions = tfrecords_reader.make_file_instructions(
573+
name=self._dataset_name,
574+
split_infos=self.values(),
575+
instruction=key,
576+
)
577+
return SubSplitInfo(instructions)
538578

539579
def __setitem__(self, key, value):
540580
raise ValueError("Cannot add elem. Use .add() instead.")
@@ -547,9 +587,9 @@ def add(self, split_info):
547587
super(SplitDict, self).__setitem__(split_info.name, split_info)
548588

549589
@classmethod
550-
def from_proto(cls, repeated_split_infos):
590+
def from_proto(cls, dataset_name, repeated_split_infos):
551591
"""Returns a new SplitDict initialized from the `repeated_split_infos`."""
552-
split_dict = cls()
592+
split_dict = cls(dataset_name)
553593
for split_info_proto in repeated_split_infos:
554594
split_info = SplitInfo()
555595
split_info.CopyFrom(split_info_proto)
@@ -567,7 +607,7 @@ def total_num_examples(self):
567607
return sum(s.num_examples for s in self.values())
568608

569609
def copy(self):
570-
return SplitDict.from_proto(self.to_proto())
610+
return SplitDict.from_proto(self._dataset_name, self.to_proto())
571611

572612

573613
def check_splits_equals(splits1, splits2):

tensorflow_datasets/core/splits_test.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class SplitsUnitTest(testing.TestCase):
9090
@classmethod
9191
def setUpClass(cls):
9292
super(SplitsUnitTest, cls).setUpClass()
93-
cls._splits = tfds.core.SplitDict()
93+
cls._splits = tfds.core.SplitDict("ds_name")
9494
cls._splits.add(tfds.core.SplitInfo(name="train", num_shards=10))
9595
cls._splits.add(tfds.core.SplitInfo(name="test", num_shards=2))
9696
cls._splits.add(tfds.core.SplitInfo(name="custom", num_shards=2))
@@ -270,7 +270,8 @@ def test_split_equality(self):
270270
self.assertNotEqual(train, train.subsplit(tfds.percent[:50]))
271271
self.assertNotEqual(train.subsplit(tfds.percent[:50]), train)
272272

273-
self.assertFalse(tfds.Split.TRAIN != "train")
273+
# Explictly want to test the `!=` operator.
274+
self.assertFalse(tfds.Split.TRAIN != "train") # pylint: disable=g-generic-assert
274275

275276
def _info(self, split):
276277
read_instruction = split.get_read_instruction(self._splits)
@@ -328,6 +329,7 @@ class SplitsOffsetIntegrationTest(testing.TestCase):
328329

329330
@classmethod
330331
def setUpClass(cls):
332+
super(SplitsOffsetIntegrationTest, cls).setUpClass()
331333
cls._builder = DummyDataset(
332334
data_dir=testing.make_tmp_dir(),
333335
range_train=range(0, 666),
@@ -375,6 +377,7 @@ class SplitsIntegrationTest(testing.TestCase):
375377

376378
@classmethod
377379
def setUpClass(cls):
380+
super(SplitsIntegrationTest, cls).setUpClass()
378381
cls._builder = DummyDataset(data_dir=testing.make_tmp_dir())
379382
cls._builder.download_and_prepare()
380383

@@ -506,7 +509,7 @@ class SplitsDictTest(testing.TestCase):
506509

507510
@property
508511
def split_dict(self):
509-
sd = splits.SplitDict()
512+
sd = splits.SplitDict("ds_name")
510513
sd.add(tfds.core.SplitInfo(name="train", num_shards=10))
511514
sd.add(tfds.core.SplitInfo(name="test", num_shards=1))
512515
return sd
@@ -519,10 +522,10 @@ def test_get(self):
519522

520523
def test_from_proto(self):
521524
sd = splits.SplitDict.from_proto(
522-
[proto.SplitInfo(name="validation", num_shards=5)])
523-
self.assertTrue("validation" in sd)
524-
self.assertFalse("train" in sd)
525-
self.assertFalse("test" in sd)
525+
"ds_name", [proto.SplitInfo(name="validation", num_shards=5)])
526+
self.assertIn("validation", sd)
527+
self.assertNotIn("train", sd)
528+
self.assertNotIn("test", sd)
526529

527530
def test_to_proto(self):
528531
sd = self.split_dict
@@ -535,26 +538,26 @@ def test_to_proto(self):
535538
self.assertEqual(10, sdp[1].num_shards)
536539

537540
def test_bool(self):
538-
sd = splits.SplitDict()
541+
sd = splits.SplitDict("ds_name")
539542
self.assertFalse(sd) # Empty split is False
540543
sd.add(tfds.core.SplitInfo(name="train", num_shards=10))
541544
self.assertTrue(sd) # Non-empty split is True
542545

543546
def test_check_splits_equals(self):
544-
s1 = splits.SplitDict()
547+
s1 = splits.SplitDict("ds_name")
545548
s1.add(tfds.core.SplitInfo(name="train", num_shards=10))
546549
s1.add(tfds.core.SplitInfo(name="test", num_shards=3))
547550

548-
s2 = splits.SplitDict()
551+
s2 = splits.SplitDict("ds_name")
549552
s2.add(tfds.core.SplitInfo(name="train", num_shards=10))
550553
s2.add(tfds.core.SplitInfo(name="test", num_shards=3))
551554

552-
s3 = splits.SplitDict()
555+
s3 = splits.SplitDict("ds_name")
553556
s3.add(tfds.core.SplitInfo(name="train", num_shards=10))
554557
s3.add(tfds.core.SplitInfo(name="test", num_shards=3))
555558
s3.add(tfds.core.SplitInfo(name="valid", num_shards=0))
556559

557-
s4 = splits.SplitDict()
560+
s4 = splits.SplitDict("ds_name")
558561
s4.add(tfds.core.SplitInfo(name="train", num_shards=11))
559562
s4.add(tfds.core.SplitInfo(name="test", num_shards=3))
560563

@@ -564,10 +567,10 @@ def test_check_splits_equals(self):
564567
self.assertFalse(splits.check_splits_equals(s1, s4)) # Nb of shards !=
565568

566569
def test_split_overwrite(self):
567-
s1 = splits.SplitDict()
570+
s1 = splits.SplitDict("ds_name")
568571
s1.add(tfds.core.SplitInfo(name="train", shard_lengths=[15]))
569572

570-
s2 = splits.SplitDict()
573+
s2 = splits.SplitDict("ds_name")
571574
s2.add(tfds.core.SplitInfo(name="train", shard_lengths=[15]))
572575

573576
self.assertTrue(splits.check_splits_equals(s1, s2))
@@ -579,5 +582,45 @@ def test_split_overwrite(self):
579582
self.assertFalse(splits.check_splits_equals(s1, s2))
580583

581584

585+
class SplitsSubsplitTest(testing.TestCase):
586+
587+
@classmethod
588+
def setUpClass(cls):
589+
super(SplitsSubsplitTest, cls).setUpClass()
590+
cls._builder = testing.DummyDatasetSharedGenerator(
591+
data_dir=testing.make_tmp_dir())
592+
cls._builder.download_and_prepare()
593+
594+
def test_sub_split_num_examples(self):
595+
s = self._builder.info.splits
596+
self.assertEqual(s["train[75%:]"].num_examples, 5)
597+
self.assertEqual(s["train[:75%]"].num_examples, 15)
598+
self.assertEqual(
599+
s["train"].num_examples,
600+
s["train[75%:]"].num_examples + s["train[:75%]"].num_examples,
601+
)
602+
603+
self.assertEqual(s["test[75%:]"].num_examples, 2)
604+
self.assertEqual(s["test[:75%]"].num_examples, 8)
605+
self.assertEqual(
606+
s["test"].num_examples,
607+
s["test[75%:]"].num_examples + s["test[:75%]"].num_examples,
608+
)
609+
610+
def test_sub_split_file_instructions(self):
611+
fi = self._builder.info.splits["train[75%:]"].file_instructions
612+
self.assertEqual(fi, [{
613+
"filename":
614+
"dummy_dataset_shared_generator-train.tfrecord-00000-of-00001",
615+
"skip": 15,
616+
"take": -1,
617+
}])
618+
619+
def test_sub_split_wrong_key(self):
620+
with self.assertRaisesWithPredicateMatch(
621+
ValueError, "Unknown split \"unknown\""):
622+
_ = self._builder.info.splits["unknown"]
623+
624+
582625
if __name__ == "__main__":
583626
testing.test_main()

0 commit comments

Comments
 (0)