Skip to content

Commit e1c9535

Browse files
Conchylicultorcopybara-github
authored andcommitted
Expose file_instructions to SplitInfo
PiperOrigin-RevId: 292600969
1 parent a708d50 commit e1c9535

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

tensorflow_datasets/core/splits.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,17 @@ def __repr__(self):
4545
num_examples = self.num_examples or "unknown"
4646
return "<tfds.core.SplitInfo num_examples=%s>" % str(num_examples)
4747

48+
@property
49+
def file_instructions(self):
50+
"""Returns the list of dict(filename, take, skip)."""
51+
# `self._dataset_name` is assigned in `SplitDict.add()`.
52+
instructions = tfrecords_reader.make_file_instructions(
53+
name=self._dataset_name,
54+
split_infos=[self],
55+
instruction=str(self.name),
56+
)
57+
return instructions.file_instructions
58+
4859

4960
class SubSplitInfo(object):
5061
"""Wrapper around a sub split info.
@@ -583,7 +594,10 @@ def add(self, split_info):
583594
"""Add the split info."""
584595
if split_info.name in self:
585596
raise ValueError("Split {} already present".format(split_info.name))
586-
# TODO(epot): Make sure this works with Named splits correctly.
597+
# Forward the dataset name required to build file instructions:
598+
# info.splits['train'].file_instructions
599+
# Use `object.__setattr__`, because ProtoCls forbid new fields assignement.
600+
object.__setattr__(split_info, "_dataset_name", self._dataset_name)
587601
super(SplitDict, self).__setitem__(split_info.name, split_info)
588602

589603
@classmethod

tensorflow_datasets/core/splits_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,15 @@ def test_sub_split_file_instructions(self):
616616
"take": -1,
617617
}])
618618

619+
def test_split_file_instructions(self):
620+
fi = self._builder.info.splits["train"].file_instructions
621+
self.assertEqual(fi, [{
622+
"filename":
623+
"dummy_dataset_shared_generator-train.tfrecord-00000-of-00001",
624+
"skip": 0,
625+
"take": -1,
626+
}])
627+
619628
def test_sub_split_wrong_key(self):
620629
with self.assertRaisesWithPredicateMatch(
621630
ValueError, "Unknown split \"unknown\""):

0 commit comments

Comments
 (0)