Commit 51a66d4

Conchylicultor authored and copybara-github committed
Add doc on customizing decoding
PiperOrigin-RevId: 258624307
1 parent 57dbf41 commit 51a66d4

7 files changed: +170 -4 lines changed

docs/_book.yaml

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ upper_tabs:
   path: /datasets/splits
 - title: Add a dataset
   path: /datasets/add_dataset
+- title: Feature decoding
+  path: /datasets/decode
 - title: Add huge datasets
   path: /datasets/beam_datasets
 - title: Store your dataset on GCS

docs/decode.md

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
# Customizing feature decoding

* [Usage examples](#usage-examples)
  * [Skipping the image decoding](#skipping-the-image-decoding)
  * [Filter/shuffle dataset before images get decoded](#filtershuffle-dataset-before-images-get-decoded)
  * [Cropping and decoding at the same time](#cropping-and-decoding-at-the-same-time)
  * [Customizing video decoding](#customizing-video-decoding)

The `tfds.decode` API allows you to override the default feature decoding. The main
use case is to skip the image decoding for better performance.

Warning: This API gives you access to the low-level `tf.train.Example` format on
disk (as defined by the `FeatureConnector`). This API is targeted towards
advanced users who want better read performance with images.

## Usage examples

### Skipping the image decoding

To keep full control over the decoding pipeline, or to apply a filter before the
images get decoded (for better performance), you can skip the image decoding
entirely. This works with both `tfds.features.Image` and `tfds.features.Video`.

```python
import tensorflow as tf
import tensorflow_datasets as tfds

ds = tfds.load('imagenet2012', split='train', decoders={
    'image': tfds.decode.SkipDecoding(),
})

for example in ds.take(1):
  assert example['image'].dtype == tf.string  # Images are not decoded
```
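
With decoding skipped, each `example['image']` holds the encoded file bytes, so they can be inspected before a decoder is chosen. The following is a minimal sketch in plain Python, outside of `tf.data`; the helper name `sniff_image_format` is hypothetical and not part of TFDS. It reads the container format from the leading magic bytes, which can matter when a later step uses a JPEG-only op.

```python
# Well-known file signatures: JPEG start-of-image marker and PNG signature.
JPEG_MAGIC = b'\xff\xd8\xff'
PNG_MAGIC = b'\x89PNG\r\n\x1a\n'


def sniff_image_format(encoded: bytes) -> str:
  """Returns 'jpeg', 'png' or 'unknown' based on the leading magic bytes."""
  if encoded.startswith(JPEG_MAGIC):
    return 'jpeg'
  if encoded.startswith(PNG_MAGIC):
    return 'png'
  return 'unknown'


# Hand-written headers standing in for real encoded images (assumption:
# real data would come from the undecoded `tf.string` tensor).
print(sniff_image_format(b'\xff\xd8\xff\xe0' + b'\x00' * 16))
print(sniff_image_format(b'\x89PNG\r\n\x1a\n' + b'\x00' * 16))
```

In eager mode, the bytes of a scalar `tf.string` tensor are available through `.numpy()`, so a check like this could be applied to a few examples while debugging a pipeline.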

### Filter/shuffle dataset before images get decoded

As in the previous example, you can use `tfds.decode.SkipDecoding()` to
insert additional `tf.data` pipeline steps before the image is decoded.
That way, filtered-out images are never decoded, and the shuffle buffer holds
compact encoded bytes instead of decoded tensors, so it can be larger.

```python
# Load the base dataset without decoding
ds, ds_info = tfds.load(
    'imagenet2012',
    split='train',
    decoders={
        'image': tfds.decode.SkipDecoding(),  # Image won't be decoded here
    },
    as_supervised=True,
    with_info=True,
)
# Apply filter and shuffle
ds = ds.filter(lambda image, label: label != 10)
ds = ds.shuffle(10000)
# Then decode with ds_info.features['image'], returning an (image, label) tuple
ds = ds.map(
    lambda image, label: (ds_info.features['image'].decode_example(image), label))
```
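
The effect of the ordering can be illustrated without TensorFlow. The sketch below is a toy model, not TFDS code: `fake_decode` is a hypothetical stand-in for an expensive image decode, and the records are made-up `(encoded_bytes, label)` pairs. It counts how often the decode runs when filtering happens after versus before decoding.

```python
decode_calls = 0


def fake_decode(encoded):
  """Stand-in for an expensive image decode; counts how often it runs."""
  global decode_calls
  decode_calls += 1
  return encoded.decode('utf-8').upper()


# (encoded_bytes, label) records standing in for undecoded examples.
records = [(b'img%d' % i, i % 4) for i in range(12)]

# Decode first, then filter: every record pays the decoding cost.
decode_calls = 0
decoded_then_filtered = [
    (image, label)
    for image, label in [(fake_decode(e), l) for e, l in records]
    if label != 0
]
calls_decode_first = decode_calls

# Filter first, then decode: only the surviving records are decoded.
decode_calls = 0
filtered_then_decoded = [(fake_decode(e), l) for e, l in records if l != 0]
calls_filter_first = decode_calls

print(calls_decode_first, calls_filter_first)
```

Both orderings produce the same examples; only the number of decode calls differs, which is the saving the `SkipDecoding()` pattern aims for.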

### Cropping and decoding at the same time

To override the default `tf.io.decode_image` operation, you can create a new
`tfds.decode.Decoder` object using the `tfds.decode.make_decoder()` decorator.

```python
@tfds.decode.make_decoder()
def decode_example(serialized_image, feature):
  crop_y, crop_x, crop_height, crop_width = 10, 10, 64, 64
  return tf.image.decode_and_crop_jpeg(
      serialized_image,
      [crop_y, crop_x, crop_height, crop_width],
      channels=feature.shape[-1],  # `feature` is the `Image` feature
  )

ds = tfds.load('imagenet2012', split='train', decoders={
    # Each image is decoded and cropped in a single op
    'image': decode_example(),
})
```

Which is equivalent to:

```python
import functools

def decode_example(example, feature):
  crop_y, crop_x, crop_height, crop_width = 10, 10, 64, 64
  # Without `as_supervised`, each element is a dict of features
  example['image'] = tf.image.decode_and_crop_jpeg(
      example['image'],
      [crop_y, crop_x, crop_height, crop_width],
      channels=feature.shape[-1],
  )
  return example

ds, ds_info = tfds.load(
    'imagenet2012',
    split='train',
    with_info=True,
    decoders={
        'image': tfds.decode.SkipDecoding(),  # Skip image decoding
    },
)
ds = ds.map(functools.partial(decode_example, feature=ds_info.features['image']))
```
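
Note that the decorated function is called (`decode_example()`) when it is placed in the `decoders` dict: the decorator turns the function into a factory of decoder objects. The sketch below shows one way such a decorator could work in plain Python. It is an illustration under assumptions, not the TFDS implementation; `FnDecoder`, its `setup` method and the toy `crop_decoder` are hypothetical names.

```python
import functools


class FnDecoder:
  """Hypothetical decoder object wrapping a user function (not TFDS code)."""

  def __init__(self, fn, *args, **kwargs):
    self._fn = fn
    self._args = args
    self._kwargs = kwargs
    self.feature = None

  def setup(self, feature):
    # A loader would attach the feature being decoded before use.
    self.feature = feature

  def decode_example(self, serialized):
    return self._fn(serialized, self.feature, *self._args, **self._kwargs)


def make_decoder():
  """Turns `fn(serialized, feature, ...)` into a factory of decoder objects."""
  def decorator(fn):
    @functools.wraps(fn)
    def factory(*args, **kwargs):
      return FnDecoder(fn, *args, **kwargs)
    return factory
  return decorator


@make_decoder()
def crop_decoder(serialized, feature, width=4):
  # Toy "decode": keep the first `width` bytes and tag them with the feature.
  return (feature, serialized[:width])


decoder = crop_decoder(width=2)  # Calling the factory builds the object.
decoder.setup(feature='image')
print(decoder.decode_example(b'abcdef'))
```

A factory of this shape lets extra arguments (here `width`) be bound per dataset, while the feature is supplied later by whatever code applies the decoder.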

### Customizing video decoding

Videos are `Sequence(Image())`. Custom decoders are applied to individual
frames, which means that decoders written for images are automatically
compatible with video.

```python
@tfds.decode.make_decoder()
def decode_example(serialized_image, feature):
  crop_y, crop_x, crop_height, crop_width = 10, 10, 64, 64
  return tf.image.decode_and_crop_jpeg(
      serialized_image,
      [crop_y, crop_x, crop_height, crop_width],
      channels=feature.feature.shape[-1],
  )

ds = tfds.load('ucf101', split='train', decoders={
    # With video, decoders are applied to individual frames
    'video': decode_example(),
})
```

Which is equivalent to:

```python
def decode_frame(serialized_image):
  """Decodes a single frame."""
  crop_y, crop_x, crop_height, crop_width = 10, 10, 64, 64
  return tf.image.decode_and_crop_jpeg(
      serialized_image,
      [crop_y, crop_x, crop_height, crop_width],
      channels=ds_info.features['video'].shape[-1],
  )


def decode_video(example):
  """Decodes all individual frames of the video."""
  video = example['video']
  video = tf.map_fn(
      decode_frame,
      video,
      dtype=ds_info.features['video'].dtype,
      parallel_iterations=10,
      back_prop=False,
  )
  example['video'] = video
  return example


ds, ds_info = tfds.load('ucf101', split='train', with_info=True, decoders={
    'video': tfds.decode.SkipDecoding(),  # Skip frame decoding
})
ds = ds.map(decode_video)  # Decode the video
```
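
The per-frame behaviour and the `[crop_y, crop_x, crop_height, crop_width]` window can be checked on toy data. The sketch below is plain Python with nested lists instead of tensors; `crop_frame` and `apply_to_frames` are hypothetical helpers that only mimic the shape of the computation, not the TensorFlow ops.

```python
def crop_frame(frame, crop_y, crop_x, crop_height, crop_width):
  """Crops one frame given as a list of pixel rows."""
  rows = frame[crop_y:crop_y + crop_height]
  return [row[crop_x:crop_x + crop_width] for row in rows]


def apply_to_frames(frames, frame_decoder):
  """Applies an image-level decoder to every frame of a video."""
  return [frame_decoder(frame) for frame in frames]


# A toy 3-frame "video" of 6x8 frames; each pixel stores (frame, y, x).
video = [
    [[(t, y, x) for x in range(8)] for y in range(6)]
    for t in range(3)
]

cropped = apply_to_frames(video, lambda f: crop_frame(f, 1, 2, 4, 3))
print(len(cropped), len(cropped[0]), len(cropped[0][0]))  # frames, height, width
print(cropped[2][0][0])  # top-left pixel of the last cropped frame
```

The number of frames is unchanged, while each frame shrinks to `crop_height` by `crop_width`, starting at row `crop_y` and column `crop_x`.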

docs/release_notes.md

Lines changed: 2 additions & 0 deletions
@@ -16,3 +16,5 @@
 * It is now possible to add arbitrary metadata to `tfds.core.DatasetInfo`
   which will be stored/restored with the dataset. See `tfds.core.Metadata`.
 * Better proxy support, possibility to add certificate
+* Add `decoders` kwargs to override the default feature decoding
+  ([guide](https://github.com/tensorflow/datasets/tree/master/docs/decode.md)).

tensorflow_datasets/core/dataset_builder.py

Lines changed: 3 additions & 1 deletion
@@ -374,7 +374,9 @@ def as_dataset(self,
   Defaults to `True` if `split == tfds.Split.TRAIN` and `False` otherwise.
 decoders: Nested dict of `Decoder` objects which allow to customize the
   decoding. The structure should match the feature structure, but only
-  customized feature keys need to be present.
+  customized feature keys need to be present. See
+  [the guide](https://github.com/tensorflow/datasets/tree/master/docs/decode.md)
+  for more info.
 as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
   will have a 2-tuple structure `(input, label)` according to
   `builder.info.supervised_keys`. If `False`, the default,
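
The `decoders` docstring above says the dict mirrors the feature structure but only needs the customized keys. A rough sketch of that lookup rule, in plain Python with callables standing in for features and decoders, could look as follows; `decode_nested` and the toy feature dict are hypothetical and do not reflect how TFDS implements it.

```python
def decode_nested(serialized, features, decoders=None):
  """Decodes a nested example, using overrides where `decoders` has a key."""
  decoders = decoders or {}
  decoded = {}
  for key, feature in features.items():
    override = decoders.get(key)
    if isinstance(feature, dict):
      # Nested feature: recurse with the matching sub-dict of overrides.
      decoded[key] = decode_nested(serialized[key], feature, override)
    elif override is not None:
      decoded[key] = override(serialized[key])
    else:
      decoded[key] = feature(serialized[key])
  return decoded


# Default decoders, one callable per leaf feature (toy stand-ins).
features = {
    'image': lambda b: 'decoded:' + b.decode(),
    'label': int,
    'meta': {'id': int, 'thumb': lambda b: 'decoded:' + b.decode()},
}
serialized = {
    'image': b'raw-image',
    'label': '3',
    'meta': {'id': '7', 'thumb': b'raw-thumb'},
}

# Only the customized key is listed; everything else decodes normally.
skip = lambda b: b
out = decode_nested(serialized, features, {'meta': {'thumb': skip}})
print(out)
```

Here only `meta/thumb` is left undecoded, while `image`, `label` and `meta/id` fall back to their default decoding.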

tensorflow_datasets/core/decode/base.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def no_op_decoder(example, feature):
   \"\"\"Decoder simply decoding feature normally.\"\"\"
   return feature.decode_example(example)
 
-tfds.load('mnist', split='train', decoder: {
+tfds.load('mnist', split='train', decoders: {
     'image': no_op_decoder(),
 })
 ```

tensorflow_datasets/core/features/top_level_feature.py

Lines changed: 3 additions & 1 deletion
@@ -54,7 +54,9 @@ def decode_example(self, serialized_example, decoders=None):
 serialized_example: Nested `dict` of `tf.Tensor`
 decoders: Nested dict of `Decoder` objects which allow to customize the
   decoding. The structure should match the feature structure, but only
-  customized feature keys need to be present.
+  customized feature keys need to be present. See
+  [the guide](https://github.com/tensorflow/datasets/tree/master/docs/decode.md)
+  for more info.
 
 Returns:
   example: Nested `dict` containing the decoded nested examples.

tensorflow_datasets/core/registered.py

Lines changed: 3 additions & 1 deletion
@@ -250,7 +250,9 @@ def load(name,
   features.
 decoders: Nested dict of `Decoder` objects which allow to customize the
   decoding. The structure should match the feature structure, but only
-  customized feature keys need to be present.
+  customized feature keys need to be present. See
+  [the guide](https://github.com/tensorflow/datasets/tree/master/docs/decode.md)
+  for more info.
 with_info: `bool`, if True, tfds.load will return the tuple
   (tf.data.Dataset, tfds.core.DatasetInfo) containing the info associated
   with the builder.
