Skip to content

Commit 02f8cca

Browse files
Conchylicultor authored and copybara-github committed
Add a decode kwargs to as_dataset to customize decoding
PiperOrigin-RevId: 257476585
1 parent 7f17a1d commit 02f8cca

File tree

10 files changed

+480
-42
lines changed

10 files changed

+480
-42
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,10 @@ def as_dataset(self,
312312
split=None,
313313
batch_size=None,
314314
shuffle_files=None,
315+
decoders=None,
315316
as_supervised=False,
316317
in_memory=None):
318+
# pylint: disable=line-too-long
317319
"""Constructs a `tf.data.Dataset`.
318320
319321
Callers must pass arguments as keyword arguments.
@@ -330,6 +332,9 @@ def as_dataset(self,
330332
`tf.data.Dataset`.
331333
shuffle_files: `bool`, whether to shuffle the input files.
332334
Defaults to `True` if `split == tfds.Split.TRAIN` and `False` otherwise.
335+
decoders: Nested dict of `Decoder` objects which allow customizing the
336+
decoding. The structure should match the feature structure, but only
337+
customized feature keys need to be present.
333338
as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
334339
will have a 2-tuple structure `(input, label)` according to
335340
`builder.info.supervised_keys`. If `False`, the default,
@@ -347,6 +352,7 @@ def as_dataset(self,
347352
If `batch_size` is -1, will return feature dictionaries containing
348353
the entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
349354
"""
355+
# pylint: enable=line-too-long
350356
logging.info("Constructing tf.data.Dataset for split %s, from %s",
351357
split, self._data_dir)
352358
if not tf.io.gfile.exists(self._data_dir):
@@ -365,14 +371,21 @@ def as_dataset(self,
365371
self._build_single_dataset,
366372
shuffle_files=shuffle_files,
367373
batch_size=batch_size,
374+
decoders=decoders,
368375
as_supervised=as_supervised,
369376
in_memory=in_memory,
370377
)
371378
datasets = utils.map_nested(build_single_dataset, split, map_tuple=True)
372379
return datasets
373380

374-
def _build_single_dataset(self, split, shuffle_files, batch_size,
375-
as_supervised, in_memory):
381+
def _build_single_dataset(
382+
self,
383+
split,
384+
shuffle_files,
385+
batch_size,
386+
decoders,
387+
as_supervised,
388+
in_memory):
376389
"""as_dataset for a single split."""
377390
if isinstance(split, six.string_types):
378391
split = splits_lib.Split(split)
@@ -424,13 +437,15 @@ def _build_single_dataset(self, split, shuffle_files, batch_size,
424437
# If using in_memory, escape all device contexts so we can load the data
425438
# with a local Session.
426439
with tf.device(None):
427-
dataset = self._as_dataset(split=split, shuffle_files=shuffle_files)
440+
dataset = self._as_dataset(
441+
split=split, shuffle_files=shuffle_files, decoders=decoders)
428442
# Use padded_batch so that features with unknown shape are supported.
429443
dataset = dataset.padded_batch(full_bs, dataset.output_shapes)
430444
dataset = tf.data.Dataset.from_tensor_slices(
431445
next(dataset_utils.as_numpy(dataset)))
432446
else:
433-
dataset = self._as_dataset(split=split, shuffle_files=shuffle_files)
447+
dataset = self._as_dataset(
448+
split=split, shuffle_files=shuffle_files, decoders=decoders)
434449

435450
if batch_size:
436451
# Use padded_batch so that features with unknown shape are supported.
@@ -567,16 +582,18 @@ def _download_and_prepare(self, dl_manager, download_config=None):
567582
raise NotImplementedError
568583

569584
@abc.abstractmethod
570-
def _as_dataset(self, split, shuffle_files=None):
585+
def _as_dataset(self, split, decoders=None, shuffle_files=None):
571586
"""Constructs a `tf.data.Dataset`.
572587
573588
This is the internal implementation to overwrite called when user calls
574589
`as_dataset`. It should read the pre-processed datasets files and generate
575590
the `tf.data.Dataset` object.
576591
577592
Args:
578-
split (`tfds.Split`): which subset of the data to read.
579-
shuffle_files (bool): whether to shuffle the input files. Optional,
593+
split: `tfds.Split` which subset of the data to read.
594+
decoders: Nested structure of `Decoder` objects to customize the dataset
595+
decoding.
596+
shuffle_files: `bool`, whether to shuffle the input files. Optional,
580597
defaults to `True` if `split == tfds.Split.TRAIN` and `False` otherwise.
581598
582599
Returns:
@@ -759,7 +776,12 @@ def _download_and_prepare(self, dl_manager, **prepare_split_kwargs):
759776
# Update the info object with the splits.
760777
self.info.update_splits_if_different(split_dict)
761778

762-
def _as_dataset(self, split=splits_lib.Split.TRAIN, shuffle_files=False):
779+
def _as_dataset(
780+
self,
781+
split=splits_lib.Split.TRAIN,
782+
decoders=None,
783+
shuffle_files=False):
784+
763785
if self.version.implements(utils.Experiment.S3):
764786
dataset = self._tfrecords_reader.read(
765787
self.name, split, self.info.splits.values(), shuffle_files)
@@ -780,9 +802,11 @@ def _as_dataset(self, split=splits_lib.Split.TRAIN, shuffle_files=False):
780802
dataset_from_file_fn=self._file_format_adapter.dataset_from_filename,
781803
shuffle_files=shuffle_files,
782804
)
805+
806+
decode_fn = functools.partial(
807+
self.info.features.decode_example, decoders=decoders)
783808
dataset = dataset.map(
784-
self.info.features.decode_example,
785-
num_parallel_calls=tf.data.experimental.AUTOTUNE)
809+
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
786810
return dataset
787811

788812
def _slice_split_info_to_instruction_dicts(self, list_sliced_split_info):
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Decoder public API.
17+
18+
"""
19+
20+
from tensorflow_datasets.core.decode.base import Decoder
21+
from tensorflow_datasets.core.decode.base import make_decoder
22+
from tensorflow_datasets.core.decode.base import SkipDecoding
23+
24+
__all__ = [
25+
'Decoder',
26+
'make_decoder',
27+
'SkipDecoding',
28+
]
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Base decoders.
17+
"""
18+
19+
from __future__ import absolute_import
20+
from __future__ import division
21+
from __future__ import print_function
22+
23+
import abc
24+
import functools
25+
26+
import six
27+
from tensorflow_datasets.core import api_utils
28+
from tensorflow_datasets.core.utils import py_utils
29+
30+
31+
@six.add_metaclass(abc.ABCMeta)
class Decoder(object):
  """Abstract base class for custom feature decoders.

  A `tfds.decode.Decoder` replaces the default decoding of a feature, either
  through a subclass or by bypassing decoding altogether with
  `tfds.decode.SkipDecoding`.

  As an alternative to subclassing, the `tfds.decode.make_decoder` decorator
  builds a `Decoder` from a plain function.

  Every decoder must derive from this class. Implementations can read
  `self.feature`, the `FeatureConnector` to which the decoder is applied; it
  is `None` until `setup` has been called.

  The main method to override is `decode_example`, which receives the
  serialized feature and returns the decoded feature.

  If `decode_example` changes the output dtype, the `dtype` property must be
  overridden as well. This enables compatibility with
  `tfds.features.Sequence`.
  """

  def __init__(self):
    # The feature is only known once `setup` is called.
    self.feature = None

  @api_utils.disallow_positional_args
  def setup(self, feature):
    """Deferred initialization of the decoder.

    A decoder only learns which builder/feature it is used with after it has
    been constructed, so the binding to the feature happens here rather than
    in `__init__`.

    Args:
      feature: `tfds.features.FeatureConnector`, the feature to which this
        decoder is applied.
    """
    self.feature = feature

  @property
  def dtype(self):
    """Returns the `dtype` after decoding.

    The value mirrors the (possibly nested) structure returned by
    `self.feature.get_tensor_info()`, with each entry replaced by its
    `dtype`.
    """
    return py_utils.map_nested(
        lambda info: info.dtype, self.feature.get_tensor_info())

  @abc.abstractmethod
  def decode_example(self, serialized_example):
    """Decode the example feature field (eg: image).

    Args:
      serialized_example: `tf.Tensor` as decoded, the dtype/shape should be
        identical to `feature.get_serialized_info()`

    Returns:
      example: Decoded example.
    """
    raise NotImplementedError('Abstract class')
90+
91+
92+
class SkipDecoding(Decoder):
  """Decoder which skips the decoding entirely.

  The serialized feature is forwarded unchanged, and `dtype` reports the
  dtype(s) of the serialized representation rather than the decoded one.

  Example of usage:

  ```python
  ds = tfds.load(
      'imagenet2012',
      split='train',
      decoders={
          'image': tfds.decode.SkipDecoding(),
      }
  )

  for ex in ds.take(1):
    assert ex['image'].dtype == tf.string
  ```
  """

  @property
  def dtype(self):
    """Returns the `dtype` of the serialized (non-decoded) feature.

    The value mirrors the (possibly nested) structure returned by
    `self.feature.get_serialized_info()`, with each entry replaced by its
    `dtype`.
    """
    tensor_info = self.feature.get_serialized_info()
    return py_utils.map_nested(lambda t: t.dtype, tensor_info)

  def decode_example(self, serialized_example):
    """Forward the serialized feature field."""
    return serialized_example
119+
120+
121+
class DecoderFn(Decoder):
  """Decoder created by the `tfds.decode.make_decoder` decorator.

  Wraps a plain function so that it can be used wherever a `Decoder` is
  expected. The function is called as
  `fn(serialized_example, feature, *args, **kwargs)`.
  """

  def __init__(self, fn, output_dtype, *args, **kwargs):
    """Stores the wrapped function and the arguments to forward to it.

    Args:
      fn: Callable performing the decoding.
      output_dtype: The dtype reported by the `dtype` property, or `None` to
        fall back to the default `Decoder.dtype`.
      *args: Extra positional arguments forwarded to `fn` on each call.
      **kwargs: Extra keyword arguments forwarded to `fn` on each call.
    """
    super(DecoderFn, self).__init__()
    self._fn = fn
    self._output_dtype = output_dtype
    self._args = args
    self._kwargs = kwargs

  @property
  def dtype(self):
    """Returns `output_dtype` if it was given, else the default dtype."""
    if self._output_dtype is None:
      return super(DecoderFn, self).dtype
    return self._output_dtype

  def decode_example(self, serialized_example):
    """Decode the example using the function."""
    return self._fn(
        serialized_example, self.feature, *self._args, **self._kwargs)
142+
143+
144+
def make_decoder(output_dtype=None):
  """Decorator to create a decoder.

  The decorated function should have the signature `(example, feature, *args,
  **kwargs) -> decoded_example`.

   * `example`: Serialized example before decoding
   * `feature`: `FeatureConnector` associated with the example
   * `*args, **kwargs`: Optional additional arguments forwarded to the function

  After decoration, the function name refers to a factory: calling it (with
  the optional extra `*args, **kwargs`) creates the decoder object.

  Example:

  ```
  @tfds.decode.make_decoder()
  def no_op_decoder(example, feature):
    \"\"\"Decoder simply decoding feature normally.\"\"\"
    return feature.decode_example(example)

  tfds.load('mnist', split='train', decoders={
      'image': no_op_decoder(),
  })
  ```

  Args:
    output_dtype: The output dtype after decoding. Required only if the decoded
      example has a different type than the `FeatureConnector.dtype` and is
      used to decode features inside sequences (ex: videos)

  Returns:
    A decorator which turns the decoding function into a factory returning
    the decoder object.
  """  # pylint: disable=g-docstring-has-escape

  def decorator(fn):

    # `wraps` keeps the name and docstring of the user function on the factory.
    @functools.wraps(fn)
    def decorated(*args, **kwargs):
      return DecoderFn(fn, output_dtype, *args, **kwargs)
    return decorated

  return decorator

0 commit comments

Comments
 (0)