Commit 361a292

WIP Consolidate reads of multiple chunks in the same shard

1 parent ef39891 commit 361a292

1 file changed: src/zarr/codecs/sharding.py (+106 -20 lines)

@@ -38,11 +38,13 @@
 from zarr.core.common import (
     ChunkCoords,
     ChunkCoordsLike,
+    concurrent_map,
     parse_enum,
     parse_named_configuration,
     parse_shapelike,
     product,
 )
+from zarr.core.config import config
 from zarr.core.indexing import (
     BasicIndexer,
     SelectorTuple,
@@ -327,6 +329,11 @@ async def finalize(
         return await shard_builder.finalize(index_location, index_encoder)


+class _ChunkCoordsByteSlice(NamedTuple):
+    coords: ChunkCoords
+    byte_slice: slice
+
+
 @dataclass(frozen=True)
 class ShardingCodec(
     ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
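
For orientation: _ChunkCoordsByteSlice pairs a chunk's coordinates with the byte range that chunk occupies inside the shard object, and it is the unit the coalescing logic added further down operates on. A minimal standalone sketch (the coordinates and offsets are invented for illustration):

from typing import NamedTuple

ChunkCoords = tuple[int, ...]  # stand-in for the alias imported from zarr.core.common

class _ChunkCoordsByteSlice(NamedTuple):
    coords: ChunkCoords
    byte_slice: slice

# e.g. chunk (0, 1) is stored at bytes [4096, 8192) of the shard
entry = _ChunkCoordsByteSlice((0, 1), slice(4096, 8192))
assert entry.byte_slice.stop - entry.byte_slice.start == 4096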
@@ -490,32 +497,21 @@ async def _decode_partial_single(
         all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

         # reading bytes of all requested chunks
-        shard_dict: ShardMapping = {}
+        shard_dict_maybe: ShardMapping | None = {}
         if self._is_total_shard(all_chunk_coords, chunks_per_shard):
             # read entire shard
             shard_dict_maybe = await self._load_full_shard_maybe(
-                byte_getter=byte_getter,
-                prototype=chunk_spec.prototype,
-                chunks_per_shard=chunks_per_shard,
+                byte_getter, chunk_spec.prototype, chunks_per_shard
             )
-            if shard_dict_maybe is None:
-                return None
-            shard_dict = shard_dict_maybe
         else:
             # read some chunks within the shard
-            shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
-            if shard_index is None:
-                return None
-            shard_dict = {}
-            for chunk_coords in all_chunk_coords:
-                chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
-                if chunk_byte_slice:
-                    chunk_bytes = await byte_getter.get(
-                        prototype=chunk_spec.prototype,
-                        byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
-                    )
-                    if chunk_bytes:
-                        shard_dict[chunk_coords] = chunk_bytes
+            shard_dict_maybe = await self._load_partial_shard_maybe(
+                byte_getter, chunk_spec.prototype, chunks_per_shard, all_chunk_coords
+            )
+
+        if shard_dict_maybe is None:
+            return None
+        shard_dict = shard_dict_maybe

         # decoding chunks and writing them into the output buffer
         await self.codec_pipeline.read(
@@ -537,6 +533,96 @@ async def _decode_partial_single(
         else:
             return out

+    async def _load_partial_shard_maybe(
+        self,
+        byte_getter: ByteGetter,
+        prototype: BufferPrototype,
+        chunks_per_shard: ChunkCoords,
+        all_chunk_coords: set[ChunkCoords],
+    ) -> ShardMapping | None:
+        shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
+        if shard_index is None:
+            return None
+
+        chunks = [
+            _ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
+            for chunk_coords in all_chunk_coords
+            if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
+        ]
+        if len(chunks) == 0:
+            return {}
+
+        groups = self._coalesce_chunks(chunks)
+
+        shard_dicts = await concurrent_map(
+            [(group, byte_getter, prototype) for group in groups],
+            self._get_group_bytes,
+            config.get("async.concurrency"),
+        )
+
+        shard_dict: ShardMutableMapping = {}
+        for d in shard_dicts:
+            shard_dict.update(d)
+
+        return shard_dict
+
+    def _coalesce_chunks(
+        self,
+        chunks: list[_ChunkCoordsByteSlice],
+        max_gap_bytes: int = 2**20,  # 1MiB
+        coalesce_max_bytes: int = 100 * 2**20,  # 100MiB
+    ) -> list[list[_ChunkCoordsByteSlice]]:
+        sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)
+
+        groups = []
+        current_group = [sorted_chunks[0]]
+
+        for chunk in sorted_chunks[1:]:
+            gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
+            current_group_size = (
+                current_group[-1].byte_slice.stop - current_group[0].byte_slice.start
+            )
+            if gap_to_chunk < max_gap_bytes and current_group_size < coalesce_max_bytes:
+                current_group.append(chunk)
+            else:
+                groups.append(current_group)
+                current_group = [chunk]
+
+        groups.append(current_group)
+
+        from pprint import pprint
+
+        pprint(
+            [
+                f"{len(g)} chunks, {(g[-1].byte_slice.stop - g[0].byte_slice.start) / 1e6:.1f}MB"
+                for g in groups
+            ]
+        )
+
+        return groups
+
+    async def _get_group_bytes(
+        self,
+        group: list[_ChunkCoordsByteSlice],
+        byte_getter: ByteGetter,
+        prototype: BufferPrototype,
+    ) -> ShardMapping:
+        group_start = group[0].byte_slice.start
+        group_end = group[-1].byte_slice.stop
+
+        group_bytes = await byte_getter.get(
+            prototype=prototype,
+            byte_range=RangeByteRequest(group_start, group_end),
+        )
+        if group_bytes is None:
+            return {}
+
+        shard_dict = {}
+        for chunk in group:
+            s = slice(chunk.byte_slice.start - group_start, chunk.byte_slice.stop - group_start)
+            shard_dict[chunk.coords] = group_bytes[s]
+        return shard_dict
+
     async def _encode_single(
         self,
         shard_array: NDBuffer,
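
To make the coalescing rule concrete, here is a standalone sketch of the grouping logic and of the offset arithmetic that slices per-chunk bytes back out of a single group read. It uses the same default thresholds as _coalesce_chunks above; the sample byte ranges, the in-memory shard stand-in, and the helper name coalesce_chunks are invented for illustration and are not part of this commit:

from typing import NamedTuple

ChunkCoords = tuple[int, ...]

class _ChunkCoordsByteSlice(NamedTuple):
    coords: ChunkCoords
    byte_slice: slice

def coalesce_chunks(
    chunks: list[_ChunkCoordsByteSlice],
    max_gap_bytes: int = 2**20,  # start a new group when the gap to the next chunk is >= 1 MiB
    coalesce_max_bytes: int = 100 * 2**20,  # or when the current group already spans >= 100 MiB
) -> list[list[_ChunkCoordsByteSlice]]:
    # assumes chunks is non-empty; the caller in the diff returns early on an empty list
    sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)
    groups: list[list[_ChunkCoordsByteSlice]] = [[sorted_chunks[0]]]
    for chunk in sorted_chunks[1:]:
        group = groups[-1]
        gap = chunk.byte_slice.start - group[-1].byte_slice.stop
        size = group[-1].byte_slice.stop - group[0].byte_slice.start
        if gap < max_gap_bytes and size < coalesce_max_bytes:
            group.append(chunk)
        else:
            groups.append([chunk])
    return groups

# Three chunks: two are adjacent in the shard, the third sits ~3 MB away,
# so it exceeds the gap threshold and gets its own ranged read.
chunks = [
    _ChunkCoordsByteSlice((0, 0), slice(0, 1_000)),
    _ChunkCoordsByteSlice((0, 1), slice(1_000, 2_000)),
    _ChunkCoordsByteSlice((3, 0), slice(3_000_000, 3_001_000)),
]
groups = coalesce_chunks(chunks)
assert [len(g) for g in groups] == [2, 1]

# Each group becomes one ranged request; per-chunk bytes are then sliced
# back out relative to the group's start offset, as _get_group_bytes does.
shard = bytes(4_000_000)  # stand-in for the shard object held by the store
for group in groups:
    start, stop = group[0].byte_slice.start, group[-1].byte_slice.stop
    group_bytes = shard[start:stop]  # one GET with byte_range=(start, stop)
    for c in group:
        chunk_bytes = group_bytes[c.byte_slice.start - start : c.byte_slice.stop - start]
        assert len(chunk_bytes) == c.byte_slice.stop - c.byte_slice.start

In this example the three requested chunks produce two store requests instead of three; on shards with many small, nearby chunks the reduction in request count is correspondingly larger.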
