Skip to content

Commit c726994

Browse files
committed
WIP Consolidate reads of multiple chunks in the same shard
Add test and make max gap and max coalesce size config options Code clarity and comments Test that chunk request coalescing reduces calls to store Profile a few values for coalesce_max_gap Update [doc]tests to include new sharding.read.* values document sharded read config options in user-guide/config.rst tweak logic: start new coalesced group if coalescing would exceed `coalesce_max_bytes` previous logic only started a new group if existing group was size already exceeded coalesce_max_bytes. set `mypy_path = "src"` to help pre-commit mypy find imported classes Reorder methods in sharding.py, add docstring + commenting wording docs fix docstring clarification trigger precommit on all python files changed in this pull request trying to get the ruff format that's happening locally during pre-commit to match the pre-commit run that is failing on CI. revert trigger for pre-commit ruff format
1 parent 1359aec commit c726994

File tree

6 files changed

+246
-34
lines changed

6 files changed

+246
-34
lines changed

docs/user-guide/config.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ Configuration options include the following:
3333
- Async and threading options, e.g. ``async.concurrency`` and ``threading.max_workers``
3434
- Selections of implementations of codecs, codec pipelines and buffers
3535
- Enabling GPU support with ``zarr.config.enable_gpu()``. See :ref:`user-guide-gpu` for more.
36+
- Tuning reads from sharded zarrs. When reading less than a complete shard, reads of nearby chunks
37+
within the same shard will be combined into a single request if they are less than
38+
``sharding.read.coalesce_max_gap_bytes`` apart and the combined request size is less than
39+
``sharding.read.coalesce_max_bytes``.
3640

3741
For selecting custom implementations of codecs, pipelines, buffers and ndbuffers,
3842
first register the implementations in the registry and then select them in the config.
@@ -79,4 +83,6 @@ This is the current default configuration::
7983
'default_zarr_format': 3,
8084
'json_indent': 2,
8185
'ndbuffer': 'zarr.buffer.cpu.NDBuffer',
82-
'threading': {'max_workers': None}}
86+
'sharding': {'read': {'coalesce_max_bytes': 104857600,
87+
'coalesce_max_gap_bytes': 1048576}},
88+
'threading': {'max_workers': None}}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ ignore = [
354354
[tool.mypy]
355355
python_version = "3.11"
356356
ignore_missing_imports = true
357+
mypy_path = "src"
357358
namespace_packages = false
358359

359360
strict = true

src/zarr/codecs/sharding.py

Lines changed: 142 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@
3838
from zarr.core.common import (
3939
ChunkCoords,
4040
ChunkCoordsLike,
41+
concurrent_map,
4142
parse_enum,
4243
parse_named_configuration,
4344
parse_shapelike,
4445
product,
4546
)
47+
from zarr.core.config import config
4648
from zarr.core.dtype.npy.int import UInt64
4749
from zarr.core.indexing import (
4850
BasicIndexer,
@@ -198,7 +200,9 @@ async def from_bytes(
198200

199201
@classmethod
200202
def create_empty(
201-
cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None
203+
cls,
204+
chunks_per_shard: ChunkCoords,
205+
buffer_prototype: BufferPrototype | None = None,
202206
) -> _ShardReader:
203207
if buffer_prototype is None:
204208
buffer_prototype = default_buffer_prototype()
@@ -248,7 +252,9 @@ def merge_with_morton_order(
248252

249253
@classmethod
250254
def create_empty(
251-
cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None
255+
cls,
256+
chunks_per_shard: ChunkCoords,
257+
buffer_prototype: BufferPrototype | None = None,
252258
) -> _ShardBuilder:
253259
if buffer_prototype is None:
254260
buffer_prototype = default_buffer_prototype()
@@ -329,9 +335,18 @@ async def finalize(
329335
return await shard_builder.finalize(index_location, index_encoder)
330336

331337

338+
class _ChunkCoordsByteSlice(NamedTuple):
    """Holds a chunk's coordinates and its byte range in a serialized shard."""

    # N-dimensional coordinates of the chunk within its shard.
    coords: ChunkCoords
    # slice(start, stop) of the chunk's encoded bytes within the shard blob.
    byte_slice: slice
343+
344+
332345
@dataclass(frozen=True)
333346
class ShardingCodec(
334-
ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
347+
ArrayBytesCodec,
348+
ArrayBytesCodecPartialDecodeMixin,
349+
ArrayBytesCodecPartialEncodeMixin,
335350
):
336351
chunk_shape: ChunkCoords
337352
codecs: tuple[Codec, ...]
@@ -508,32 +523,21 @@ async def _decode_partial_single(
508523
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}
509524

510525
# reading bytes of all requested chunks
511-
shard_dict: ShardMapping = {}
526+
shard_dict_maybe: ShardMapping | None = {}
512527
if self._is_total_shard(all_chunk_coords, chunks_per_shard):
513528
# read entire shard
514529
shard_dict_maybe = await self._load_full_shard_maybe(
515-
byte_getter=byte_getter,
516-
prototype=chunk_spec.prototype,
517-
chunks_per_shard=chunks_per_shard,
530+
byte_getter, chunk_spec.prototype, chunks_per_shard
518531
)
519-
if shard_dict_maybe is None:
520-
return None
521-
shard_dict = shard_dict_maybe
522532
else:
523533
# read some chunks within the shard
524-
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
525-
if shard_index is None:
526-
return None
527-
shard_dict = {}
528-
for chunk_coords in all_chunk_coords:
529-
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
530-
if chunk_byte_slice:
531-
chunk_bytes = await byte_getter.get(
532-
prototype=chunk_spec.prototype,
533-
byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
534-
)
535-
if chunk_bytes:
536-
shard_dict[chunk_coords] = chunk_bytes
534+
shard_dict_maybe = await self._load_partial_shard_maybe(
535+
byte_getter, chunk_spec.prototype, chunks_per_shard, all_chunk_coords
536+
)
537+
538+
if shard_dict_maybe is None:
539+
return None
540+
shard_dict = shard_dict_maybe
537541

538542
# decoding chunks and writing them into the output buffer
539543
await self.codec_pipeline.read(
@@ -615,7 +619,9 @@ async def _encode_partial_single(
615619

616620
indexer = list(
617621
get_indexer(
618-
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
622+
selection,
623+
shape=shard_shape,
624+
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
619625
)
620626
)
621627

@@ -689,7 +695,8 @@ def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int:
689695
get_pipeline_class()
690696
.from_codecs(self.index_codecs)
691697
.compute_encoded_size(
692-
16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard)
698+
16 * product(chunks_per_shard),
699+
self._get_index_chunk_spec(chunks_per_shard),
693700
)
694701
)
695702

@@ -734,7 +741,8 @@ async def _load_shard_index_maybe(
734741
)
735742
else:
736743
index_bytes = await byte_getter.get(
737-
prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size)
744+
prototype=numpy_buffer_prototype(),
745+
byte_range=SuffixByteRequest(shard_index_size),
738746
)
739747
if index_bytes is not None:
740748
return await self._decode_shard_index(index_bytes, chunks_per_shard)
@@ -748,7 +756,10 @@ async def _load_shard_index(
748756
) or _ShardIndex.create_empty(chunks_per_shard)
749757

750758
async def _load_full_shard_maybe(
751-
self, byte_getter: ByteGetter, prototype: BufferPrototype, chunks_per_shard: ChunkCoords
759+
self,
760+
byte_getter: ByteGetter,
761+
prototype: BufferPrototype,
762+
chunks_per_shard: ChunkCoords,
752763
) -> _ShardReader | None:
753764
shard_bytes = await byte_getter.get(prototype=prototype)
754765

@@ -758,6 +769,110 @@ async def _load_full_shard_maybe(
758769
else None
759770
)
760771

772+
async def _load_partial_shard_maybe(
    self,
    byte_getter: ByteGetter,
    prototype: BufferPrototype,
    chunks_per_shard: ChunkCoords,
    all_chunk_coords: set[ChunkCoords],
) -> ShardMapping | None:
    """
    Read chunks from `byte_getter` for the case where the read is less than a full shard.

    Nearby chunks are coalesced into grouped byte-range requests (see
    `_coalesce_chunks`), and the groups are fetched concurrently.
    Returns a mapping of chunk coordinates to bytes, an empty mapping when
    none of the requested chunks exist in the shard, or ``None`` when the
    shard index itself could not be read.
    """
    shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
    if shard_index is None:
        # Shard (or its index) is missing entirely.
        return None

    chunks = [
        _ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
        for chunk_coords in all_chunk_coords
        # Drop chunks where index lookup fails
        if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
    ]
    if len(chunks) == 0:
        return {}

    # Partition the chunks into groups that can each be fetched in one request.
    groups = self._coalesce_chunks(chunks)

    # Fetch each group concurrently, bounded by the configured concurrency limit.
    shard_dicts = await concurrent_map(
        [(group, byte_getter, prototype) for group in groups],
        self._get_group_bytes,
        config.get("async.concurrency"),
    )

    # Merge the per-group mappings into a single chunk-coords -> bytes mapping.
    shard_dict: ShardMutableMapping = {}
    for d in shard_dicts:
        shard_dict.update(d)

    return shard_dict
809+
810+
def _coalesce_chunks(
811+
self,
812+
chunks: list[_ChunkCoordsByteSlice],
813+
) -> list[list[_ChunkCoordsByteSlice]]:
814+
"""
815+
Combine chunks from a single shard into groups that should be read together
816+
in a single request.
817+
818+
Respects the following configuration options:
819+
- `sharding.read.coalesce_max_gap_bytes`: The maximum gap between
820+
chunks to coalesce into a single group.
821+
- `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group.
822+
"""
823+
max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes")
824+
coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes")
825+
826+
sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)
827+
828+
groups = []
829+
current_group = [sorted_chunks[0]]
830+
831+
for chunk in sorted_chunks[1:]:
832+
gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
833+
size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start
834+
if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes:
835+
current_group.append(chunk)
836+
else:
837+
groups.append(current_group)
838+
current_group = [chunk]
839+
840+
groups.append(current_group)
841+
842+
return groups
843+
844+
async def _get_group_bytes(
845+
self,
846+
group: list[_ChunkCoordsByteSlice],
847+
byte_getter: ByteGetter,
848+
prototype: BufferPrototype,
849+
) -> ShardMapping:
850+
"""
851+
Reads a possibly coalesced group of one or more chunks from a shard.
852+
Returns a mapping of chunk coordinates to bytes.
853+
"""
854+
group_start = group[0].byte_slice.start
855+
group_end = group[-1].byte_slice.stop
856+
857+
# A single call to retrieve the bytes for the entire group.
858+
group_bytes = await byte_getter.get(
859+
prototype=prototype,
860+
byte_range=RangeByteRequest(group_start, group_end),
861+
)
862+
if group_bytes is None:
863+
return {}
864+
865+
# Extract the bytes corresponding to each chunk in group from group_bytes.
866+
shard_dict = {}
867+
for chunk in group:
868+
chunk_slice = slice(
869+
chunk.byte_slice.start - group_start,
870+
chunk.byte_slice.stop - group_start,
871+
)
872+
shard_dict[chunk.coords] = group_bytes[chunk_slice]
873+
874+
return shard_dict
875+
761876
def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int:
762877
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
763878
return input_byte_length + self._shard_index_size(chunks_per_shard)

src/zarr/core/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def enable_gpu(self) -> ConfigSet:
111111
},
112112
"async": {"concurrency": 10, "timeout": None},
113113
"threading": {"max_workers": None},
114+
"sharding": {
115+
"read": {
116+
"coalesce_max_bytes": 100 * 2**20, # 100MiB
117+
"coalesce_max_gap_bytes": 2**20, # 1MiB
118+
}
119+
},
114120
"json_indent": 2,
115121
"codec_pipeline": {
116122
"path": "zarr.core.codec_pipeline.BatchedCodecPipeline",

0 commit comments

Comments
 (0)