Skip to content

Commit 40dde0b

Browse files
pierrot0 and copybara-github
authored and committed
tfrecords_writer: new mechanism to write tfrecord files.
PiperOrigin-RevId: 251858270
1 parent 3065442 commit 40dde0b

File tree

2 files changed

+299
-0
lines changed

2 files changed

+299
-0
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""To write records into sharded tfrecord files."""
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import os
22+
23+
from absl import logging
24+
import tensorflow as tf
25+
26+
from tensorflow_datasets.core import example_serializer
27+
from tensorflow_datasets.core import shuffle
28+
from tensorflow_datasets.core import utils
29+
30+
# Shard-size bounds used by `_get_number_shards` to pick a shard count:
# shards should be at least MIN_SHARD_SIZE and at most MAX_SHARD_SIZE bytes.
MIN_SHARD_SIZE = 64<<20  # 64 MiB
# NOTE(review): the value is 1024 MiB = 1 GiB, but the original comment said
# "2 GiB" — confirm which one is intended before relying on either.
MAX_SHARD_SIZE = 1024<<20  # 1 GiB

# TFRECORD overheads: per-record framing bytes added by the TFRecord format.
# https://github.com/tensorflow/tensorflow/blob/27325fabed898880fa1b33a04d4b125a6ef4bbc8/tensorflow/core/lib/io/record_writer.h#L104
TFRECORD_REC_OVERHEAD = 16
36+
37+
38+
class _TFRecordWriter(object):
39+
"""Writes examples in given order, to specified number of shards."""
40+
41+
def __init__(self, path, num_examples, number_of_shards):
42+
"""Init _TFRecordWriter instance.
43+
44+
Args:
45+
path: where to write TFRecord file.
46+
num_examples (int): the number of examples which will be written.
47+
number_of_shards (int): the desired number of shards. The examples will be
48+
evenly distributed among that number of shards.
49+
"""
50+
self._path = path
51+
self._shard_boundaries = [round(num_examples * (float(i)/number_of_shards))
52+
for i in range(1, number_of_shards+1)]
53+
self._num_shards = number_of_shards
54+
self._shards_length = [] # The number of Example in each shard.
55+
self._number_written_examples = 0
56+
self._init_new_shard(0)
57+
58+
def _init_new_shard(self, new_shard_number):
59+
self._current_shard = new_shard_number
60+
self._current_shard_limit = self._shard_boundaries.pop(0)
61+
self._current_writer = None
62+
self._current_shard_length = 0
63+
64+
def _flush_current_shard(self):
65+
if self._current_shard_length == 0:
66+
return # Nothing was written.
67+
self._current_writer.flush()
68+
self._current_writer.close()
69+
self._current_writer = None
70+
self._shards_length.append(self._current_shard_length)
71+
72+
def _get_path(self, index):
73+
return '%s-%05d-of-%05d' % (self._path, index, self._num_shards)
74+
75+
def _write(self, serialized_example):
76+
if not self._current_writer:
77+
fpath = self._get_path(self._current_shard)
78+
logging.info('Creating file %s', fpath)
79+
self._current_writer = tf.io.TFRecordWriter(fpath)
80+
self._current_writer.write(serialized_example)
81+
self._current_shard_length += 1
82+
self._number_written_examples += 1
83+
84+
def write(self, serialized_example):
85+
"""Write given example, starts new shard when needed."""
86+
if self._number_written_examples >= self._current_shard_limit:
87+
self._flush_current_shard()
88+
self._init_new_shard(self._current_shard + 1)
89+
self._write(serialized_example)
90+
91+
def finalize(self):
92+
"""Finalize files, returns list containing the length of each shard."""
93+
self._flush_current_shard()
94+
return self._shards_length
95+
96+
97+
def _get_number_shards(total_size, num_examples):
  """Returns number of shards for num_examples of total_size in bytes.

  Each shard should be at least MIN_SHARD_SIZE (64 MiB) and at most
  MAX_SHARD_SIZE bytes.
  A pod has 16*16=256 TPU devices containing 1024 TPU chips (2048 cores).
  So if the dataset is large enough, we want the number of shards to be a
  multiple of 1024, but with shards as big as possible.
  If the dataset is too small, we want the number of shards to be a power
  of two so it distributes better on smaller TPU configs (8, 16, 32, ... cores).

  Args:
    total_size: the size of the data (serialized, not counting any overhead).
    num_examples: the number of records in the data.

  Returns:
    number of shards to use.
  """
  # Account for the per-record framing bytes the TFRecord format adds.
  total_size += num_examples * TFRECORD_REC_OVERHEAD
  max_shards_number = total_size // MIN_SHARD_SIZE
  min_shards_number = total_size // MAX_SHARD_SIZE
  if min_shards_number <= 1024 <= max_shards_number and num_examples >= 1024:
    return 1024
  elif min_shards_number > 1024:
    # Find the smallest multiple of 1024 that keeps each shard below
    # MAX_SHARD_SIZE while not exceeding one shard per example.
    i = 2
    while True:
      n = 1024 * i
      if n >= min_shards_number:
        if num_examples >= n:
          return n
        # num_examples can never reach a larger n, so no multiple of 1024
        # works: fall back to one shard per example (shards will exceed
        # MAX_SHARD_SIZE). Without this guard the loop would never terminate.
        return max(num_examples, 1)
      i += 1
  else:
    # Small dataset: pick the largest power of two that fits the size
    # bounds and the number of examples.
    for n in [512, 256, 128, 64, 32, 16, 8, 4, 2]:
      if min_shards_number <= n <= max_shards_number and num_examples >= n:
        return n
    return 1
131+
132+
133+
class Writer(object):
  """Shuffles and writes Examples to sharded TFRecord files.

  The number of shards is computed automatically.

  This class is a replacement for file_format_adapter.TFRecordExampleAdapter,
  which will eventually be deleted.
  """

  def __init__(self, example_specs, path):
    self._serializer = example_serializer.ExampleSerializer(example_specs)
    self._shuffler = shuffle.Shuffler(os.path.dirname(path))
    self._num_examples = 0
    self._path = path

  def write(self, key, example):
    """Writes given Example.

    The given example is not directly written to the tfrecord file, but to a
    temporary file (or memory). The finalize() method does write the tfrecord
    files.

    Args:
      key (int|bytes): the key associated with the example. Used for shuffling.
      example: the Example to write to the tfrecord file.
    """
    # Serialization happens up-front; the shuffler only sees opaque bytes.
    self._shuffler.add(key, self._serializer.serialize_example(example))
    self._num_examples += 1

  def finalize(self):
    """Effectively writes examples to the tfrecord files."""
    print('Shuffling and writing examples to %s' % self._path)
    num_shards = _get_number_shards(self._shuffler.size, self._num_examples)
    record_writer = _TFRecordWriter(self._path, self._num_examples, num_shards)
    progress = utils.tqdm(
        self._shuffler, total=self._num_examples,
        unit=' examples', leave=False)
    # Iterating the shuffler yields the serialized examples in shuffled order.
    for serialized_example in progress:
      record_writer.write(serialized_example)
    shard_lengths = record_writer.finalize()
    logging.info('Done writing %s. Shard lengths: %s',
                 self._path, shard_lengths)
    return shard_lengths
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for tensorflow_datasets.core.tfrecords_writer."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
24+
from absl.testing import absltest
25+
import tensorflow as tf
26+
from tensorflow_datasets import testing
27+
from tensorflow_datasets.core import dataset_utils
28+
from tensorflow_datasets.core import example_serializer
29+
from tensorflow_datasets.core import tfrecords_writer
30+
31+
32+
class GetNumberShardsTest(testing.TestCase):
  """Checks `_get_number_shards` across dataset-size regimes."""

  def test_imagenet_train(self):
    # Large dataset with many examples: snaps to 1024 shards (one TPU pod).
    size = 137<<30  # 137 GiB
    num_examples = 1281167
    n = tfrecords_writer._get_number_shards(size, num_examples)
    self.assertEqual(n, 1024)

  def test_imagenet_evaluation(self):
    # Mid-size dataset: a power of two below 1024.
    size = 6300 * (1<<20)  # 6.3 GiB
    num_examples = 50000
    n = tfrecords_writer._get_number_shards(size, num_examples)
    self.assertEqual(n, 64)

  def test_verylarge_few_examples(self):
    # Shard count is capped by the number of examples.
    size = 52<<30  # 52 GiB
    num_examples = 512
    n = tfrecords_writer._get_number_shards(size, num_examples)
    self.assertEqual(n, 512)

  def test_xxl(self):
    # Very large dataset: smallest fitting multiple of 1024 (11 * 1024).
    size = 10<<40  # 10 TiB
    num_examples = 10**9  # 1G
    n = tfrecords_writer._get_number_shards(size, num_examples)
    self.assertEqual(n, 11264)

  def test_xs(self):
    # Too small to be worth splitting: a single shard.
    size = 100<<20  # 100 MiB
    num_examples = 100 * 10**3  # 100K
    n = tfrecords_writer._get_number_shards(size, num_examples)
    self.assertEqual(n, 1)

  def test_m(self):
    size = 400<<20  # 400 MiB (the original comment incorrectly said 499 MiB)
    num_examples = 200 * 10**3  # 200K
    n = tfrecords_writer._get_number_shards(size, num_examples)
    self.assertEqual(n, 4)
69+
70+
71+
def _read_records(path):
  """Returns (files_names, list_of_records_in_each_file)."""
  file_names = sorted(tf.io.gfile.listdir(path))
  records_per_file = []
  for file_name in file_names:
    dataset = tf.data.TFRecordDataset(os.path.join(path, file_name))
    records_per_file.append(list(dataset_utils.as_numpy(dataset)))
  return file_names, records_per_file
80+
81+
82+
class _DummySerializer(object):
83+
84+
def __init__(self, specs):
85+
del specs
86+
87+
def serialize_example(self, example):
88+
return bytes(example)
89+
90+
91+
class WriterTest(testing.TestCase):
  """End-to-end check of Writer with the serializer stubbed out."""

  @absltest.mock.patch.object(
      example_serializer, 'ExampleSerializer', _DummySerializer)
  def test_write(self):
    """Writes 8 records in 5 shards.

    Number of records is evenly distributed (2-1-2-1-2).
    """
    path = os.path.join(self.tmp_dir, 'foo.tfrecord')
    writer = tfrecords_writer.Writer('some spec', path)
    to_write = [
        (1, b'a'), (2, b'b'),
        (3, b'c'),
        (4, b'd'), (5, b'e'),
        (6, b'f'),
        (7, b'g'), (8, b'h'),
    ]
    for key, record in to_write:
      writer.write(key, record)
    # Force exactly 5 shards so the expected layout is deterministic.
    patched_shards = absltest.mock.patch.object(
        tfrecords_writer, '_get_number_shards', return_value=5)
    with patched_shards:
      shards_length = writer.finalize()
    self.assertEqual(shards_length, [2, 1, 2, 1, 2])
    written_files, all_recs = _read_records(self.tmp_dir)
    expected_names = [
        'foo.tfrecord-0000%s-of-00005' % i for i in range(5)]
    self.assertEqual(written_files, expected_names)
    # Record order within/across shards comes from the shuffler.
    self.assertEqual(all_recs, [
        [b'f', b'e'], [b'b'], [b'a', b'g'], [b'h'], [b'c', b'd'],
    ])
121+
122+
# Allows running this test file directly: `python tfrecords_writer_test.py`.
if __name__ == '__main__':
  testing.test_main()

0 commit comments

Comments
 (0)