Skip to content

Commit b6f9b00

Browse files
Conchylicultor and copybara-github
authored and committed
Expose auto-cache default value to the documentation
PiperOrigin-RevId: 296013946
1 parent f734af4 commit b6f9b00

File tree

3 files changed

+58
-11
lines changed

3 files changed

+58
-11
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,6 @@ def _build_single_dataset(
541541
read_config=read_config,
542542
)
543543
# Auto-cache small datasets which are small enough to fit in memory.
544-
# TODO(tfds): Should expose auto-caching default value in the dataset doc.
545544
if self._should_cache_ds(
546545
split=split,
547546
shuffle_files=shuffle_files,

tensorflow_datasets/scripts/document_datasets_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
"""Test of `document_datasets.py`."""
1617
from __future__ import absolute_import
1718
from __future__ import division
1819
from __future__ import print_function

tensorflow_datasets/scripts/templates/dataset.mako.md

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ Displayed in https://www.tensorflow.org/datasets/catalog/.
77

88
import collections
99
import tensorflow_datasets as tfds
10-
from tensorflow_datasets.core.utils.py_utils import get_class_path
11-
from tensorflow_datasets.core.utils.py_utils import get_class_url
10+
from tensorflow_datasets.core.utils import py_utils
1211

1312
%>
1413

@@ -36,16 +35,14 @@ ${builder.info.description}
3635
</%def>
3736

3837
<%def name="display_source(builder)">\
39-
* **Source code**: [`${get_class_path(builder)}`](${get_class_url(builder)})
38+
* **Source code**:
39+
[`${py_utils.get_class_path(builder)}`](${py_utils.get_class_url(builder)})
4040
</%def>
4141

4242
<%def name="display_versions(builder)">\
4343
<%
4444
def list_versions(builder):
45-
# List all available versions
46-
# Sort them in order
47-
# Get the default version
48-
for v in builder.versions:
45+
for v in builder.versions: # List all available versions (in default order)
4946
if v == builder.version: # Highlight the default version
5047
version_name = '**`{}`** (default)'.format(str(v))
5148
else:
@@ -63,6 +60,54 @@ def list_versions(builder):
6360
* **Dataset size**: `${tfds.units.size_str(builder.info.dataset_size)}`
6461
</%def>
6562

63+
<%
64+
def build_autocached_info(builder):
  """Returns the auto-cache information string.

  Queries `builder._should_cache_ds` for every split listed in
  `builder.info.splits`, once with `shuffle_files=True` and once with
  `shuffle_files=False`, and summarizes the answers.

  Args:
    builder: Object exposing `info.splits` (with `.keys()`) and a
      `_should_cache_ds(split, shuffle_files=..., read_config=...)` method.

  Returns:
    `'Yes'` when every split is cached in both modes, `'No'` when no split is
    cached in either mode, otherwise a comma-separated breakdown that lists
    the split names belonging to each group.
  """
  always_cached = set()
  never_cached = set()
  unshuffle_cached = set()
  split_names = [str(name) for name in builder.info.splits.keys()]
  for split_name in split_names:
    cache_shuffled = builder._should_cache_ds(
        split_name, shuffle_files=True, read_config=tfds.ReadConfig())
    cache_unshuffled = builder._should_cache_ds(
        split_name, shuffle_files=False, read_config=tfds.ReadConfig())

    if cache_shuffled and cache_unshuffled:
      always_cached.add(split_name)
    elif not cache_shuffled and not cache_unshuffled:
      never_cached.add(split_name)
    else:  # Dataset is only cached when `shuffle_files` is False
      assert not cache_shuffled and cache_unshuffled
      unshuffle_cached.add(split_name)

  # NOTE: with zero splits the first comparison is `0 == 0`, so the result is
  # 'Yes'; this matches the previous behavior.
  if len(always_cached) == len(split_names):
    return 'Yes'  # All splits are auto-cached.
  if len(never_cached) == len(split_names):
    return 'No'  # Splits never auto-cached.

  # Some splits cached, some not. Names are sorted before joining because set
  # iteration order is not stable across runs, and the generated
  # documentation should not change from one build to the next.
  autocached_info_parts = []
  if always_cached:
    split_names_str = ', '.join(sorted(always_cached))
    autocached_info_parts.append('Yes ({})'.format(split_names_str))
  if never_cached:
    split_names_str = ', '.join(sorted(never_cached))
    autocached_info_parts.append('No ({})'.format(split_names_str))
  if unshuffle_cached:
    split_names_str = ', '.join(sorted(unshuffle_cached))
    autocached_info_parts.append(
        'Only when `shuffle_files=False` ({})'.format(split_names_str))
  return ', '.join(autocached_info_parts)
103+
104+
%>
105+
<%def name="display_autocache(builder)">\
106+
* **Auto-cached**
107+
([documentation](https://www.tensorflow.org/datasets/performances#auto-caching)):
108+
${build_autocached_info(builder)}
109+
</%def>
110+
66111
<%def name="display_manual(builder)">\
67112
% if builder.MANUAL_DOWNLOAD_INSTRUCTIONS:
68113
* **Manual download instructions**: This dataset requires you to download the
@@ -129,6 +174,7 @@ def get_versions(builder):
129174
return tuple((str(v), v.description) for v in builder.versions)
130175
def get_size(builder):
  """Returns the `(download_size, dataset_size)` pair read from `builder.info`."""
  # The previous one-line body evaluated this tuple without returning it, so
  # the function always produced None.
  return (builder.info.download_size, builder.info.dataset_size)
131176
def get_manual(builder):
  """Returns the builder's `MANUAL_DOWNLOAD_INSTRUCTIONS` attribute."""
  # The previous one-line body read the attribute without returning it, so
  # the function always produced None.
  return builder.MANUAL_DOWNLOAD_INSTRUCTIONS
177+
def get_autocache(builder):
  """Returns the auto-cache summary string built by `build_autocached_info`."""
  # The previous one-line body called the helper but discarded its result, so
  # the function always produced None.
  return build_autocached_info(builder)
132178
def get_splits(builder):
133179
return tuple(
134180
(str(s.name), int(s.num_examples)) for s in builder.info.splits.values()
@@ -145,6 +191,7 @@ all_sections = [
145191
Section(get_versions, display_versions),
146192
Section(get_size, display_size),
147193
Section(get_manual, display_manual),
194+
Section(get_autocache, display_autocache),
148195
Section(get_splits, display_splits),
149196
Section(get_features, display_features),
150197
Section(get_supervised, display_supervised),
@@ -161,7 +208,7 @@ ${section.make(builder)}\
161208
% endfor
162209
</%def>
163210

164-
## --------------------------- Builder builder ---------------------------
211+
## --------------------------- Builder configs ---------------------------
165212

166213
<%def name="display_all_builders(builders)">\
167214
<%
@@ -181,9 +228,9 @@ ${display_builder(next(iter(builders)), common_sections)}
181228

182229
% for i, builder in enumerate(builders):
183230
<%
184-
header_suffix = ' (default config)' if i == 0 else ''
231+
header_suffix = '(default config)' if i == 0 else ''
185232
%>\
186-
${'##'} ${builder.name}/${builder.builder_config.name}${header_suffix}
233+
${'##'} ${builder.name}/${builder.builder_config.name} ${header_suffix}
187234

188235
${display_builder(builder, unique_sections)}
189236
% endfor

0 commit comments

Comments
 (0)