Skip to content

Commit 5fd860b

Browse files
pierrot0 authored and copybara-github committed
move serialization / parsing functions out of file_adapter to distinct modules.
PiperOrigin-RevId: 251819216
1 parent 37b1447 commit 5fd860b

File tree

4 files changed

+303
-223
lines changed

4 files changed

+303
-223
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""To deserialize bytes (Example) to tf.Example."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import tensorflow as tf
23+
from tensorflow_datasets.core import utils
24+
25+
26+
class ExampleParser(object):
  """Decodes serialized `tf.train.Example` protos into nested tensor dicts."""

  def __init__(self, example_specs):
    """Constructor.

    Args:
      example_specs: Nested `dict` of tensor specs describing the structure of
        the examples to decode. A flattened copy is kept alongside the nested
        original so that per-field work can be done on a flat mapping.
    """
    self._example_specs = example_specs
    self._flat_example_specs = utils.flatten_nest_dict(self._example_specs)

  def _build_feature_specs(self):
    """Builds the feature specification used to parse `tf.train.Example`.

    Returns:
      A flat `dict` mapping each flattened feature name to the object produced
      by `_to_tf_example_spec` (`tf.io.FixedLenFeature`,
      `tf.io.FixedLenSequenceFeature`, ...).
    """
    feature_specs = {}
    for name, tensor_info in self._flat_example_specs.items():
      # Prefix any conversion failure with the name of the offending feature.
      error_prefix = "Specification error for feature {} ({}): ".format(
          name, tensor_info)
      with utils.try_reraise(error_prefix):
        feature_specs[name] = _to_tf_example_spec(tensor_info)
    return feature_specs

  def parse_example(self, serialized_example):
    """Deserializes a single `tf.train.Example` proto.

    Usage (with `example_parser` an `ExampleParser` instance):
    ```
    ds = tf.data.TFRecordDataset(filepath)
    ds = ds.map(example_parser.parse_example)
    ```

    Args:
      serialized_example: `tf.Tensor`, the `tf.string` tensor containing the
        serialized proto to decode.

    Returns:
      example: A nested `dict` of `tf.Tensor` values. The structure and tensors
        shape/dtype match the `example_specs` provided at construction.
    """
    flat_decoded = tf.io.parse_single_example(
        serialized=serialized_example,
        features=self._build_feature_specs(),
    )

    # Post-process each decoded field (shape/dtype) against its own spec.
    restored = {}
    for name, (tensor, tensor_info) in utils.zip_dict(
        flat_decoded, self._flat_example_specs):
      restored[name] = _deserialize_single_field(tensor, tensor_info)

    # Rebuild the original nesting from the flat mapping.
    return utils.pack_as_nest_dict(restored, self._example_specs)
78+
79+
80+
def _deserialize_single_field(example_data, tensor_info):
  """Restores the shape and dtype of one decoded field.

  Args:
    example_data: The decoded value for the field; must expose `.dtype`.
    tensor_info: Spec of the field; its `.shape` and `.dtype` are the targets.

  Returns:
    `example_data`, reshaped to `tensor_info.shape` when that shape has at most
    one unknown dimension, and cast to `tensor_info.dtype` when the dtypes
    differ.
  """
  declared_shape = tensor_info.shape

  # TF Example stores values flattened. The shape can only be restored when at
  # most one dimension is unknown, since a single `-1` is all reshape can infer.
  if declared_shape.count(None) <= 1:
    target_shape = [-1 if dim is None else dim for dim in declared_shape]
    example_data = tf.reshape(example_data, target_shape)

  # Cast back to the declared dtype when the stored dtype differs.
  needs_cast = example_data.dtype != tensor_info.dtype
  if needs_cast:
    example_data = tf.dtypes.cast(example_data, tensor_info.dtype)
  return example_data
92+
93+
94+
def _to_tf_example_spec(tensor_info):
  """Converts a `TensorInfo` into a `tf.io` parsing feature spec.

  Args:
    tensor_info: Spec of the field; `.dtype`, `.shape` and `.default_value`
      are read.

  Returns:
    A `tf.io.FixedLenFeature` when every dimension is known, or a
    `tf.io.FixedLenSequenceFeature` when only the first dimension is unknown.

  Raises:
    NotImplementedError: If the dtype is not integer, bool, floating or string,
      or if an unknown dimension appears anywhere but the first position.
  """
  source_dtype = tensor_info.dtype

  # TODO(b/119937875): TF Example protos only support int64, float32 and
  # string. This creates limitations: float64 is downsampled to float32, bool
  # is stored as int64 (space inefficient), and complex or quantized types are
  # not supported.
  if source_dtype.is_integer or source_dtype.is_bool:
    proto_dtype = tf.int64
  elif source_dtype.is_floating:
    proto_dtype = tf.float32
  elif source_dtype == tf.string:
    proto_dtype = tf.string
  else:
    # TFRecord only supports the 3 types above.
    raise NotImplementedError(
        "Serialization not implemented for dtype {}".format(tensor_info))

  # Pick the feature kind depending on where the unknown dimensions are.
  shape = tensor_info.shape
  fully_defined = all(dim is not None for dim in shape)

  if fully_defined:
    return tf.io.FixedLenFeature(
        shape=shape,
        dtype=proto_dtype,
        default_value=tensor_info.default_value,
    )

  if shape.count(None) == 1 and shape[0] is None:
    # Only the leading dimension is unknown: variable-length sequence.
    return tf.io.FixedLenSequenceFeature(
        shape=shape[1:],
        dtype=proto_dtype,
        allow_missing=True,
        default_value=tensor_info.default_value,
    )

  raise NotImplementedError(
      "Tensor with a unknown dimension not at the first position not "
      "supported: {}".format(tensor_info))
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""To serialize Dict or sequence to Example."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import numpy as np
23+
import six
24+
import tensorflow as tf
25+
26+
from tensorflow_datasets.core import utils
27+
28+
29+
class ExampleSerializer(object):
  """Encodes nested example dicts as serialized `tf.train.Example` protos."""

  def __init__(self, example_specs):
    """Constructor.

    Args:
      example_specs: Nested `dict` of `tfds.features.TensorInfo`, corresponding
        to the structure of data to write/read.
    """
    self._example_specs = example_specs
    # Flat view of the specs, used to pair each flattened value with its spec.
    self._flat_example_specs = utils.flatten_nest_dict(self._example_specs)

  def serialize_example(self, example):
    """Serializes one example.

    Args:
      example: Nested `dict` containing the input to serialize. The input
        structure and values dtype/shape must match the `example_specs`
        provided at construction.

    Returns:
      serialize_proto: `str`, the serialized `tf.train.Example` proto
    """
    flat_example = utils.flatten_nest_dict(example)
    proto = _dict_to_tf_example(flat_example, self._flat_example_specs)
    return proto.SerializeToString()
56+
57+
58+
def _dict_to_tf_example(example_dict, tensor_info_dict=None):
  """Builds tf.train.Example from (string -> int/float/str list) dictionary.

  Args:
    example_dict: `dict`, dict of values, tensor,...
    tensor_info_dict: `dict` of `tfds.feature.TensorInfo` If given, perform
      additional checks on the example dict (check dtype, shape, number of
      fields...)

  Returns:
    The `tf.train.Example` holding one `tf.train.Feature` per key.
  """
  if tensor_info_dict:
    # Pair every value with its spec.
    named_pairs = utils.zip_dict(example_dict, tensor_info_dict)
  else:
    # No specs available: each value is converted without a spec.
    named_pairs = (
        (name, (value, None)) for name, value in example_dict.items())

  features = {}
  for name, (value, tensor_info) in named_pairs:
    # Prefix any conversion failure with the name of the offending feature.
    with utils.try_reraise(
        "Error while serializing feature {} ({}): ".format(name, tensor_info)):
      features[name] = _item_to_tf_feature(value, tensor_info)

  return tf.train.Example(features=tf.train.Features(feature=features))
85+
86+
87+
def _is_string(item):
  """Returns True if `item` holds string or bytes data.

  Accepted forms are a single str/bytes value, a `list`/`tuple` whose elements
  are all str/bytes (an empty one counts), or a `np.ndarray` whose dtype is
  unicode, bytes or object.
  """
  text_types = (six.binary_type, six.string_types)

  if isinstance(item, text_types):
    return True
  if isinstance(item, (tuple, list)):
    return all(isinstance(element, text_types) for element in item)
  if isinstance(item, np.ndarray):
    # Kind "U" is unicode, "S" is bytes; object arrays are accepted as well.
    return item.dtype.kind in ("U", "S") or item.dtype == object
  return False
98+
99+
100+
def _item_to_tf_feature(item, tensor_info=None):
  """Converts a single item to a `tf.train.Feature`.

  Args:
    item: The value to serialize (scalar, str/bytes, `list`/`tuple` or
      `np.ndarray`).
    tensor_info: Optional spec of the field. When given, its dtype drives the
      numpy conversion and the value's shape is checked against its shape.

  Returns:
    A `tf.train.Feature` holding the flattened value as an int64, float or
    bytes list.

  Raises:
    ValueError: If an empty `list`/`tuple` is given without `tensor_info`, if
      the spec is `tf.string` but the value is not string-like, or if the
      resulting array dtype is not integer, bool, floating or string.
  """
  v = item
  if not tensor_info and isinstance(v, (list, tuple)) and not v:
    raise ValueError(
        "Received an empty list value, so is unable to infer the "
        "feature type to record. To support empty value, the corresponding "
        "FeatureConnector should return a numpy array with the correct dtype "
        "instead of a Python list."
    )

  # Handle strings/bytes first
  is_string = _is_string(v)

  if tensor_info:
    np_dtype = np.dtype(tensor_info.dtype.as_numpy_dtype)
  elif is_string:
    np_dtype = object  # Avoid truncating trailing '\x00' when converting to np
  else:
    np_dtype = None  # Let numpy infer the dtype from the value.

  v = np.array(v, dtype=np_dtype)

  # Check that the shape is expected
  if tensor_info:
    utils.assert_shape_match(v.shape, tensor_info.shape)
    if tensor_info.dtype == tf.string and not is_string:
      raise ValueError(
          "Unsuported value: {}\nCould not convert to bytes list.".format(item))

  # Convert boolean to integer (tf.train.Example does not support bool)
  if v.dtype == np.bool_:
    v = v.astype(int)

  v = v.flatten()  # Convert v into a 1-d array
  if np.issubdtype(v.dtype, np.integer):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=v))
  elif np.issubdtype(v.dtype, np.floating):
    return tf.train.Feature(float_list=tf.train.FloatList(value=v))
  elif is_string:
    v = [tf.compat.as_bytes(x) for x in v]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
  else:
    # `v` is always an `np.ndarray` at this point, so its Python type carries
    # no information; report the array dtype, which is what was rejected.
    raise ValueError(
        "Unsuported value: {}.\n"
        "tf.train.Feature does not support type {}. "
        "This may indicate that one of the FeatureConnectors received an "
        "unsupported value as input.".format(repr(v), repr(v.dtype))
    )

0 commit comments

Comments
 (0)