From b947eafd8cd1f35ce3e7db4783d71f9656e0cb02 Mon Sep 17 00:00:00 2001 From: Tom van der Weide Date: Thu, 13 Feb 2025 01:47:40 -0800 Subject: [PATCH] Use the max size of serialized examples to find a safe number of shards If we know the max size of serialized examples, then we can account for the worst case scenario where one shard would get only examples of the max size. This hopefully should prevent users running into problems with having too big shards. PiperOrigin-RevId: 726377778 --- .../huggingface_dataset_builder.py | 1 + tensorflow_datasets/core/reader_test.py | 1 + tensorflow_datasets/core/shuffle.py | 6 +- tensorflow_datasets/core/utils/shard_utils.py | 37 ++++++- .../core/utils/shard_utils_test.py | 100 ++++++++++++++++-- tensorflow_datasets/core/writer.py | 49 +++++++-- tensorflow_datasets/core/writer_test.py | 2 + 7 files changed, 170 insertions(+), 26 deletions(-) diff --git a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py index 6be1af3d0ed..9fad62f5aa2 100644 --- a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py +++ b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py @@ -406,6 +406,7 @@ def _compute_shard_specs( # HF split size is good enough for estimating the number of shards. num_shards = shard_utils.ShardConfig.calculate_number_shards( total_size=hf_split_info.num_bytes, + max_example_size=None, num_examples=hf_split_info.num_examples, uses_precise_sharding=False, ) diff --git a/tensorflow_datasets/core/reader_test.py b/tensorflow_datasets/core/reader_test.py index e478e26bfd7..fdc88ad42f7 100644 --- a/tensorflow_datasets/core/reader_test.py +++ b/tensorflow_datasets/core/reader_test.py @@ -97,6 +97,7 @@ def _write_tfrecord(self, split_name, shards_number, records): shard_specs = writer_lib._get_shard_specs( num_examples=num_examples, total_size=0, + max_example_size=None, bucket_lengths=[num_examples], filename_template=filename_template, shard_config=shard_utils.ShardConfig(num_shards=shards_number), diff --git a/tensorflow_datasets/core/shuffle.py b/tensorflow_datasets/core/shuffle.py index 2df14df0474..a54adbfa586 100644 --- a/tensorflow_datasets/core/shuffle.py +++ b/tensorflow_datasets/core/shuffle.py @@ -244,7 +244,7 @@ def __init__( self._total_bytes = 0 # To keep data in memory until enough data has been gathered. self._in_memory = True - self._mem_buffer = [] + self._mem_buffer: list[type_utils.KeySerializedExample] = [] self._seen_keys: set[int] = set() self._num_examples = 0 @@ -272,10 +272,10 @@ def _add_to_mem_buffer(self, hkey: int, data: bytes) -> None: if self._total_bytes > MAX_MEM_BUFFER_SIZE: for hkey, data in self._mem_buffer: self._add_to_bucket(hkey, data) - self._mem_buffer = None + self._mem_buffer = [] self._in_memory = False - def add(self, key: type_utils.Key, data: bytes) -> bool: + def add(self, key: type_utils.Key, data: bytes) -> None: """Add (key, data) to shuffler.""" if self._read_only: raise AssertionError('add() cannot be called after __iter__.') diff --git a/tensorflow_datasets/core/utils/shard_utils.py b/tensorflow_datasets/core/utils/shard_utils.py index a5a700608e2..db018e8a2e5 100644 --- a/tensorflow_datasets/core/utils/shard_utils.py +++ b/tensorflow_datasets/core/utils/shard_utils.py @@ -57,19 +57,22 @@ class ShardConfig: def calculate_number_shards( cls, total_size: int, + max_example_size: int | Sequence[int] | None, num_examples: int, uses_precise_sharding: bool = True, ) -> int: """Returns number of shards for num_examples of total_size in bytes. Args: - total_size: the size of the data (serialized, not couting any overhead). + total_size: the size of the data (serialized, not counting any overhead). + max_example_size: the maximum size of a single example (serialized, not + counting any overhead). num_examples: the number of records in the data. uses_precise_sharding: whether a mechanism is used to exactly control how many examples go in each shard. """ - total_size += num_examples * cls.overhead - max_shards_number = total_size // cls.min_shard_size + total_overhead = num_examples * cls.overhead + total_size_with_overhead = total_size + total_overhead if uses_precise_sharding: max_shard_size = cls.max_shard_size else: @@ -77,7 +80,24 @@ def calculate_number_shards( # shard (called 'precise sharding' here), we use a smaller max shard size # so that the pipeline doesn't fail if a shard gets some more examples. max_shard_size = 0.9 * cls.max_shard_size - min_shards_number = total_size // max_shard_size + max_shard_size = max(1, max_shard_size) + + if max_example_size is None: + min_shards_number = max(1, total_size_with_overhead // max_shard_size) + max_shards_number = max(1, total_size_with_overhead // cls.min_shard_size) + else: + if isinstance(max_example_size, Sequence): + if len(max_example_size) == 1: + max_example_size = max_example_size[0] + else: + raise ValueError( + 'max_example_size must be a single value or None, got' + f' {max_example_size}' + ) + pessimistic_total_size = num_examples * (max_example_size + cls.overhead) + min_shards_number = max(1, pessimistic_total_size // max_shard_size) + max_shards_number = max(1, pessimistic_total_size // cls.min_shard_size) + if min_shards_number <= 1024 <= max_shards_number and num_examples >= 1024: return 1024 elif min_shards_number > 1024: @@ -96,15 +116,22 @@ def calculate_number_shards( def get_number_shards( self, total_size: int, + max_example_size: int | None, num_examples: int, uses_precise_sharding: bool = True, ) -> int: if self.num_shards: return self.num_shards return self.calculate_number_shards( - total_size, num_examples, uses_precise_sharding + total_size=total_size, + max_example_size=max_example_size, + num_examples=num_examples, + uses_precise_sharding=uses_precise_sharding, ) + def replace(self, **kwargs: Any) -> ShardConfig: + return dataclasses.replace(self, **kwargs) + def get_shard_boundaries( num_examples: int, diff --git a/tensorflow_datasets/core/utils/shard_utils_test.py b/tensorflow_datasets/core/utils/shard_utils_test.py index 1882b178e37..325d1d15cf1 100644 --- a/tensorflow_datasets/core/utils/shard_utils_test.py +++ b/tensorflow_datasets/core/utils/shard_utils_test.py @@ -22,16 +22,94 @@ class ShardConfigTest(parameterized.TestCase): @parameterized.named_parameters( - ('imagenet train, 137 GiB', 137 << 30, 1281167, True, 1024), - ('imagenet evaluation, 6.3 GiB', 6300 * (1 << 20), 50000, True, 64), - ('very large, but few examples, 52 GiB', 52 << 30, 512, True, 512), - ('xxl, 10 TiB', 10 << 40, 10**9, True, 11264), - ('xxl, 10 PiB, 100B examples', 10 << 50, 10**11, True, 10487808), - ('xs, 100 MiB, 100K records', 10 << 20, 100 * 10**3, True, 1), - ('m, 499 MiB, 200K examples', 400 << 20, 200 * 10**3, True, 4), + dict( + testcase_name='imagenet train, 137 GiB', + total_size=137 << 30, + num_examples=1281167, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=1024, + ), + dict( + testcase_name='imagenet evaluation, 6.3 GiB', + total_size=6300 * (1 << 20), + num_examples=50000, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=64, + ), + dict( + testcase_name='very large, but few examples, 52 GiB', + total_size=52 << 30, + num_examples=512, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=512, + ), + dict( + testcase_name='xxl, 10 TiB', + total_size=10 << 40, + num_examples=10**9, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=11264, + ), + dict( + testcase_name='xxl, 10 PiB, 100B examples', + total_size=10 << 50, + num_examples=10**11, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=10487808, + ), + dict( + testcase_name='xs, 100 MiB, 100K records', + total_size=10 << 20, + num_examples=100 * 10**3, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=1, + ), + dict( + testcase_name='m, 499 MiB, 200K examples', + total_size=400 << 20, + num_examples=200 * 10**3, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=4, + ), + dict( + testcase_name='100GiB, even example sizes', + num_examples=1e9, # 1B examples + total_size=1e9 * 1000, # On average 1000 bytes per example + max_size=1000, # Max example size is 4000 bytes + uses_precise_sharding=True, + expected_num_shards=1024, + ), + dict( + testcase_name='100GiB, uneven example sizes', + num_examples=1e9, # 1B examples + total_size=1e9 * 1000, # On average 1000 bytes per example + max_size=4 * 1000, # Max example size is 4000 bytes + uses_precise_sharding=True, + expected_num_shards=4096, + ), + dict( + testcase_name='100GiB, very uneven example sizes', + num_examples=1e9, # 1B examples + total_size=1e9 * 1000, # On average 1000 bytes per example + max_size=16 * 1000, # Max example size is 16x the average bytes + uses_precise_sharding=True, + expected_num_shards=15360, + ), ) def test_get_number_shards_default_config( - self, total_size, num_examples, uses_precise_sharding, expected_num_shards + self, + total_size: int, + num_examples: int, + uses_precise_sharding: bool, + max_size: int, + expected_num_shards: int, ): shard_config = shard_utils.ShardConfig() self.assertEqual( @@ -39,6 +117,7 @@ def test_get_number_shards_default_config( shard_config.get_number_shards( total_size=total_size, num_examples=num_examples, + max_example_size=max_size, # max(1, total_size // num_examples), uses_precise_sharding=uses_precise_sharding, ), ) @@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self): self.assertEqual( 42, shard_config.get_number_shards( - total_size=100, num_examples=1, uses_precise_sharding=True + total_size=100, + max_example_size=100, + num_examples=1, + uses_precise_sharding=True, ), ) diff --git a/tensorflow_datasets/core/writer.py b/tensorflow_datasets/core/writer.py index aa756bbcfb5..abd461717ac 100644 --- a/tensorflow_datasets/core/writer.py +++ b/tensorflow_datasets/core/writer.py @@ -116,6 +116,7 @@ def _get_index_path(path: str) -> epath.PathLike: def _get_shard_specs( num_examples: int, total_size: int, + max_example_size: int | None, bucket_lengths: Sequence[int], filename_template: naming.ShardedFileTemplate, shard_config: shard_utils.ShardConfig, @@ -123,13 +124,18 @@ def _get_shard_specs( """Returns list of _ShardSpec instances, corresponding to shards to write. Args: - num_examples: int, number of examples in split. - total_size: int (bytes), sum of example sizes. + num_examples: number of examples in split. + total_size: total size in bytes, i.e., the sum of example sizes. + max_example_size: maximum size in bytes of a single example. bucket_lengths: list of ints, number of examples in each bucket. filename_template: template to format sharded filenames. shard_config: the configuration for creating shards. """ - num_shards = shard_config.get_number_shards(total_size, num_examples) + num_shards = shard_config.get_number_shards( + total_size=total_size, + max_example_size=max_example_size, + num_examples=num_examples, + ) shard_boundaries = shard_utils.get_shard_boundaries(num_examples, num_shards) shard_specs = [] bucket_indexes = [str(i) for i in range(len(bucket_lengths))] @@ -350,6 +356,7 @@ def __init__( self._filename_template = filename_template self._shard_config = shard_config or shard_utils.ShardConfig() self._example_writer = example_writer + self._max_example_size = 0 def write(self, key: int | bytes, example: Example): """Writes given example. @@ -363,6 +370,9 @@ def write(self, key: int | bytes, example: Example): """ serialized_example = self._serializer.serialize_example(example=example) self._shuffler.add(key, serialized_example) + self._max_example_size = max( + self._max_example_size, len(serialized_example) + ) def finalize(self) -> tuple[list[int], int]: """Effectively writes examples to the shards.""" @@ -372,6 +382,7 @@ def finalize(self) -> tuple[list[int], int]: shard_specs = _get_shard_specs( num_examples=self._shuffler.num_examples, total_size=self._shuffler.size, + max_example_size=self._max_example_size, bucket_lengths=self._shuffler.bucket_lengths, filename_template=self._filename_template, shard_config=self._shard_config, @@ -589,10 +600,13 @@ def _write_final_shard( id=shard_id, num_examples=len(example_by_key), size=shard_size ) - def _number_of_shards(self, num_examples: int, total_size: int) -> int: + def _number_of_shards( + self, num_examples: int, total_size: int, max_example_size: int + ) -> int: """Returns the number of shards.""" num_shards = self._shard_config.get_number_shards( total_size=total_size, + max_example_size=max_example_size, num_examples=num_examples, uses_precise_sharding=False, ) @@ -658,16 +672,26 @@ def write_from_pcollection(self, examples_pcollection): | "CountExamples" >> beam.combiners.Count.Globally() | "CheckValidNumExamples" >> beam.Map(self._check_num_examples) ) + serialized_example_sizes = ( + serialized_examples | beam.Values() | beam.Map(len) + ) total_size = beam.pvalue.AsSingleton( - serialized_examples - | beam.Values() - | beam.Map(len) - | "TotalSize" >> beam.CombineGlobally(sum) + serialized_example_sizes | "TotalSize" >> beam.CombineGlobally(sum) + ) + + max_example_size = beam.pvalue.AsSingleton( + serialized_example_sizes + | "TopExampleSize" >> beam.combiners.Top.Largest(1) + | "MaxExampleSize" >> beam.CombineGlobally(_get_max_size) ) ideal_num_shards = beam.pvalue.AsSingleton( num_examples | "NumberOfShards" - >> beam.Map(self._number_of_shards, total_size=total_size) + >> beam.Map( + self._number_of_shards, + total_size=total_size, + max_example_size=max_example_size, + ) ) examples_per_shard = ( @@ -826,3 +850,10 @@ def _get_length_and_size(shard: epath.Path) -> tuple[epath.Path, int, int]: ) return shard_lengths, total_size_bytes + + +def _get_max_size(sizes: Iterable[int]) -> int | None: + sizes = list(sizes) + if not sizes: + return None + return max(sizes) diff --git a/tensorflow_datasets/core/writer_test.py b/tensorflow_datasets/core/writer_test.py index c0b5fd30068..79b36124636 100644 --- a/tensorflow_datasets/core/writer_test.py +++ b/tensorflow_datasets/core/writer_test.py @@ -48,6 +48,7 @@ def test_1bucket_6shards(self): filetype_suffix='tfrecord', ), shard_config=shard_utils.ShardConfig(num_shards=6), + max_example_size=2, ) self.assertEqual( specs, @@ -134,6 +135,7 @@ def test_4buckets_2shards(self): filetype_suffix='tfrecord', ), shard_config=shard_utils.ShardConfig(num_shards=2), + max_example_size=2, ) self.assertEqual( specs,