38
38
from zarr .core .common import (
39
39
ChunkCoords ,
40
40
ChunkCoordsLike ,
41
+ concurrent_map ,
41
42
parse_enum ,
42
43
parse_named_configuration ,
43
44
parse_shapelike ,
44
45
product ,
45
46
)
47
+ from zarr .core .config import config
46
48
from zarr .core .dtype .npy .int import UInt64
47
49
from zarr .core .indexing import (
48
50
BasicIndexer ,
@@ -198,7 +200,9 @@ async def from_bytes(
198
200
199
201
@classmethod
200
202
def create_empty (
201
- cls , chunks_per_shard : ChunkCoords , buffer_prototype : BufferPrototype | None = None
203
+ cls ,
204
+ chunks_per_shard : ChunkCoords ,
205
+ buffer_prototype : BufferPrototype | None = None ,
202
206
) -> _ShardReader :
203
207
if buffer_prototype is None :
204
208
buffer_prototype = default_buffer_prototype ()
@@ -248,7 +252,9 @@ def merge_with_morton_order(
248
252
249
253
@classmethod
250
254
def create_empty (
251
- cls , chunks_per_shard : ChunkCoords , buffer_prototype : BufferPrototype | None = None
255
+ cls ,
256
+ chunks_per_shard : ChunkCoords ,
257
+ buffer_prototype : BufferPrototype | None = None ,
252
258
) -> _ShardBuilder :
253
259
if buffer_prototype is None :
254
260
buffer_prototype = default_buffer_prototype ()
@@ -329,9 +335,18 @@ async def finalize(
329
335
return await shard_builder .finalize (index_location , index_encoder )
330
336
331
337
338
class _ChunkCoordsByteSlice(NamedTuple):
    """Holds a chunk's coordinates and its byte range in a serialized shard."""

    # Logical coordinates of the chunk within its shard.
    coords: ChunkCoords
    # Byte offsets (start/stop) of the chunk's data within the serialized shard.
    byte_slice: slice
332
345
@dataclass (frozen = True )
333
346
class ShardingCodec (
334
- ArrayBytesCodec , ArrayBytesCodecPartialDecodeMixin , ArrayBytesCodecPartialEncodeMixin
347
+ ArrayBytesCodec ,
348
+ ArrayBytesCodecPartialDecodeMixin ,
349
+ ArrayBytesCodecPartialEncodeMixin ,
335
350
):
336
351
chunk_shape : ChunkCoords
337
352
codecs : tuple [Codec , ...]
@@ -508,32 +523,21 @@ async def _decode_partial_single(
508
523
all_chunk_coords = {chunk_coords for chunk_coords , * _ in indexed_chunks }
509
524
510
525
# reading bytes of all requested chunks
511
- shard_dict : ShardMapping = {}
526
+ shard_dict_maybe : ShardMapping | None = {}
512
527
if self ._is_total_shard (all_chunk_coords , chunks_per_shard ):
513
528
# read entire shard
514
529
shard_dict_maybe = await self ._load_full_shard_maybe (
515
- byte_getter = byte_getter ,
516
- prototype = chunk_spec .prototype ,
517
- chunks_per_shard = chunks_per_shard ,
530
+ byte_getter , chunk_spec .prototype , chunks_per_shard
518
531
)
519
- if shard_dict_maybe is None :
520
- return None
521
- shard_dict = shard_dict_maybe
522
532
else :
523
533
# read some chunks within the shard
524
- shard_index = await self ._load_shard_index_maybe (byte_getter , chunks_per_shard )
525
- if shard_index is None :
526
- return None
527
- shard_dict = {}
528
- for chunk_coords in all_chunk_coords :
529
- chunk_byte_slice = shard_index .get_chunk_slice (chunk_coords )
530
- if chunk_byte_slice :
531
- chunk_bytes = await byte_getter .get (
532
- prototype = chunk_spec .prototype ,
533
- byte_range = RangeByteRequest (chunk_byte_slice [0 ], chunk_byte_slice [1 ]),
534
- )
535
- if chunk_bytes :
536
- shard_dict [chunk_coords ] = chunk_bytes
534
+ shard_dict_maybe = await self ._load_partial_shard_maybe (
535
+ byte_getter , chunk_spec .prototype , chunks_per_shard , all_chunk_coords
536
+ )
537
+
538
+ if shard_dict_maybe is None :
539
+ return None
540
+ shard_dict = shard_dict_maybe
537
541
538
542
# decoding chunks and writing them into the output buffer
539
543
await self .codec_pipeline .read (
@@ -615,7 +619,9 @@ async def _encode_partial_single(
615
619
616
620
indexer = list (
617
621
get_indexer (
618
- selection , shape = shard_shape , chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape )
622
+ selection ,
623
+ shape = shard_shape ,
624
+ chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape ),
619
625
)
620
626
)
621
627
@@ -689,7 +695,8 @@ def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int:
689
695
get_pipeline_class ()
690
696
.from_codecs (self .index_codecs )
691
697
.compute_encoded_size (
692
- 16 * product (chunks_per_shard ), self ._get_index_chunk_spec (chunks_per_shard )
698
+ 16 * product (chunks_per_shard ),
699
+ self ._get_index_chunk_spec (chunks_per_shard ),
693
700
)
694
701
)
695
702
@@ -734,7 +741,8 @@ async def _load_shard_index_maybe(
734
741
)
735
742
else :
736
743
index_bytes = await byte_getter .get (
737
- prototype = numpy_buffer_prototype (), byte_range = SuffixByteRequest (shard_index_size )
744
+ prototype = numpy_buffer_prototype (),
745
+ byte_range = SuffixByteRequest (shard_index_size ),
738
746
)
739
747
if index_bytes is not None :
740
748
return await self ._decode_shard_index (index_bytes , chunks_per_shard )
@@ -748,7 +756,10 @@ async def _load_shard_index(
748
756
) or _ShardIndex .create_empty (chunks_per_shard )
749
757
750
758
async def _load_full_shard_maybe (
751
- self , byte_getter : ByteGetter , prototype : BufferPrototype , chunks_per_shard : ChunkCoords
759
+ self ,
760
+ byte_getter : ByteGetter ,
761
+ prototype : BufferPrototype ,
762
+ chunks_per_shard : ChunkCoords ,
752
763
) -> _ShardReader | None :
753
764
shard_bytes = await byte_getter .get (prototype = prototype )
754
765
@@ -758,6 +769,110 @@ async def _load_full_shard_maybe(
758
769
else None
759
770
)
760
771
772
+ async def _load_partial_shard_maybe (
773
+ self ,
774
+ byte_getter : ByteGetter ,
775
+ prototype : BufferPrototype ,
776
+ chunks_per_shard : ChunkCoords ,
777
+ all_chunk_coords : set [ChunkCoords ],
778
+ ) -> ShardMapping | None :
779
+ """
780
+ Read chunks from `byte_getter` for the case where the read is less than a full shard.
781
+ Returns a mapping of chunk coordinates to bytes.
782
+ """
783
+ shard_index = await self ._load_shard_index_maybe (byte_getter , chunks_per_shard )
784
+ if shard_index is None :
785
+ return None
786
+
787
+ chunks = [
788
+ _ChunkCoordsByteSlice (chunk_coords , slice (* chunk_byte_slice ))
789
+ for chunk_coords in all_chunk_coords
790
+ # Drop chunks where index lookup fails
791
+ if (chunk_byte_slice := shard_index .get_chunk_slice (chunk_coords ))
792
+ ]
793
+ if len (chunks ) == 0 :
794
+ return {}
795
+
796
+ groups = self ._coalesce_chunks (chunks )
797
+
798
+ shard_dicts = await concurrent_map (
799
+ [(group , byte_getter , prototype ) for group in groups ],
800
+ self ._get_group_bytes ,
801
+ config .get ("async.concurrency" ),
802
+ )
803
+
804
+ shard_dict : ShardMutableMapping = {}
805
+ for d in shard_dicts :
806
+ shard_dict .update (d )
807
+
808
+ return shard_dict
809
+
810
+ def _coalesce_chunks (
811
+ self ,
812
+ chunks : list [_ChunkCoordsByteSlice ],
813
+ ) -> list [list [_ChunkCoordsByteSlice ]]:
814
+ """
815
+ Combine chunks from a single shard into groups that should be read together
816
+ in a single request.
817
+
818
+ Respects the following configuration options:
819
+ - `sharding.read.coalesce_max_gap_bytes`: The maximum gap between
820
+ chunks to coalesce into a single group.
821
+ - `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group.
822
+ """
823
+ max_gap_bytes = config .get ("sharding.read.coalesce_max_gap_bytes" )
824
+ coalesce_max_bytes = config .get ("sharding.read.coalesce_max_bytes" )
825
+
826
+ sorted_chunks = sorted (chunks , key = lambda c : c .byte_slice .start )
827
+
828
+ groups = []
829
+ current_group = [sorted_chunks [0 ]]
830
+
831
+ for chunk in sorted_chunks [1 :]:
832
+ gap_to_chunk = chunk .byte_slice .start - current_group [- 1 ].byte_slice .stop
833
+ size_if_coalesced = chunk .byte_slice .stop - current_group [0 ].byte_slice .start
834
+ if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes :
835
+ current_group .append (chunk )
836
+ else :
837
+ groups .append (current_group )
838
+ current_group = [chunk ]
839
+
840
+ groups .append (current_group )
841
+
842
+ return groups
843
+
844
+ async def _get_group_bytes (
845
+ self ,
846
+ group : list [_ChunkCoordsByteSlice ],
847
+ byte_getter : ByteGetter ,
848
+ prototype : BufferPrototype ,
849
+ ) -> ShardMapping :
850
+ """
851
+ Reads a possibly coalesced group of one or more chunks from a shard.
852
+ Returns a mapping of chunk coordinates to bytes.
853
+ """
854
+ group_start = group [0 ].byte_slice .start
855
+ group_end = group [- 1 ].byte_slice .stop
856
+
857
+ # A single call to retrieve the bytes for the entire group.
858
+ group_bytes = await byte_getter .get (
859
+ prototype = prototype ,
860
+ byte_range = RangeByteRequest (group_start , group_end ),
861
+ )
862
+ if group_bytes is None :
863
+ return {}
864
+
865
+ # Extract the bytes corresponding to each chunk in group from group_bytes.
866
+ shard_dict = {}
867
+ for chunk in group :
868
+ chunk_slice = slice (
869
+ chunk .byte_slice .start - group_start ,
870
+ chunk .byte_slice .stop - group_start ,
871
+ )
872
+ shard_dict [chunk .coords ] = group_bytes [chunk_slice ]
873
+
874
+ return shard_dict
875
+
761
876
def compute_encoded_size (self , input_byte_length : int , shard_spec : ArraySpec ) -> int :
762
877
chunks_per_shard = self ._get_chunks_per_shard (shard_spec )
763
878
return input_byte_length + self ._shard_index_size (chunks_per_shard )
0 commit comments