@@ -38,11 +38,13 @@
 from zarr.core.common import (
     ChunkCoords,
     ChunkCoordsLike,
+    concurrent_map,
     parse_enum,
     parse_named_configuration,
     parse_shapelike,
     product,
 )
+from zarr.core.config import config
 from zarr.core.indexing import (
     BasicIndexer,
     SelectorTuple,
@@ -327,6 +329,12 @@ async def finalize(
         return await shard_builder.finalize(index_location, index_encoder)


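+# A chunk's coordinates within a shard, paired with its byte range in the serialized shard.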
+class _ChunkCoordsByteSlice(NamedTuple):
+    coords: ChunkCoords
+    byte_slice: slice
+
+
 @dataclass(frozen=True)
 class ShardingCodec(
     ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
@@ -490,32 +497,21 @@ async def _decode_partial_single(
         all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

         # reading bytes of all requested chunks
-        shard_dict: ShardMapping = {}
+        shard_dict_maybe: ShardMapping | None = {}
         if self._is_total_shard(all_chunk_coords, chunks_per_shard):
             # read entire shard
             shard_dict_maybe = await self._load_full_shard_maybe(
-                byte_getter=byte_getter,
-                prototype=chunk_spec.prototype,
-                chunks_per_shard=chunks_per_shard,
+                byte_getter, chunk_spec.prototype, chunks_per_shard
             )
-            if shard_dict_maybe is None:
-                return None
-            shard_dict = shard_dict_maybe
         else:
             # read some chunks within the shard
-            shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
-            if shard_index is None:
-                return None
-            shard_dict = {}
-            for chunk_coords in all_chunk_coords:
-                chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
-                if chunk_byte_slice:
-                    chunk_bytes = await byte_getter.get(
-                        prototype=chunk_spec.prototype,
-                        byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
-                    )
-                    if chunk_bytes:
-                        shard_dict[chunk_coords] = chunk_bytes
+            shard_dict_maybe = await self._load_partial_shard_maybe(
+                byte_getter, chunk_spec.prototype, chunks_per_shard, all_chunk_coords
+            )
+
+        if shard_dict_maybe is None:
+            return None
+        shard_dict = shard_dict_maybe

         # decoding chunks and writing them into the output buffer
         await self.codec_pipeline.read(
@@ -537,6 +533,98 @@ async def _decode_partial_single(
         else:
             return out

+    async def _load_partial_shard_maybe(
+        self,
+        byte_getter: ByteGetter,
+        prototype: BufferPrototype,
+        chunks_per_shard: ChunkCoords,
+        all_chunk_coords: set[ChunkCoords],
+    ) -> ShardMapping | None:
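+        """Read the requested chunks of a shard, coalescing reads of nearby chunks."""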
+        shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
+        if shard_index is None:
+            return None
+
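+        # Look up the byte range of each requested chunk that exists in this shard.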
+        chunks = [
+            _ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
+            for chunk_coords in all_chunk_coords
+            if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
+        ]
+        if len(chunks) == 0:
+            return {}
+
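+        # Group nearby chunks so each group can be fetched with one byte-range request.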
+        groups = self._coalesce_chunks(chunks)
+
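+        # Fetch the groups concurrently, bounded by the configured concurrency limit.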
+        shard_dicts = await concurrent_map(
+            [(group, byte_getter, prototype) for group in groups],
+            self._get_group_bytes,
+            config.get("async.concurrency"),
+        )
+
+        shard_dict: ShardMutableMapping = {}
+        for d in shard_dicts:
+            shard_dict.update(d)
+
+        return shard_dict
+
+    def _coalesce_chunks(
+        self,
+        chunks: list[_ChunkCoordsByteSlice],
+        max_gap_bytes: int = 2**20,  # 1MiB
+        coalesce_max_bytes: int = 100 * 2**20,  # 100MiB
+    ) -> list[list[_ChunkCoordsByteSlice]]:
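+        """Group chunks that are close together in the shard's byte layout.
+
+        A chunk joins the current group while the gap to the previous chunk is
+        under ``max_gap_bytes`` and the group stays under ``coalesce_max_bytes``.
+        """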
+        sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)
+
+        groups = []
+        current_group = [sorted_chunks[0]]
+
+        for chunk in sorted_chunks[1:]:
+            gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
+            current_group_size = (
+                current_group[-1].byte_slice.stop - current_group[0].byte_slice.start
+            )
+            if gap_to_chunk < max_gap_bytes and current_group_size < coalesce_max_bytes:
+                current_group.append(chunk)
+            else:
+                groups.append(current_group)
+                current_group = [chunk]
+
+        groups.append(current_group)
+
+        return groups
+
+    async def _get_group_bytes(
+        self,
+        group: list[_ChunkCoordsByteSlice],
+        byte_getter: ByteGetter,
+        prototype: BufferPrototype,
+    ) -> ShardMapping:
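+        """Fetch a group's contiguous byte range in one request and split it per chunk."""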
+        group_start = group[0].byte_slice.start
+        group_end = group[-1].byte_slice.stop
+
+        group_bytes = await byte_getter.get(
+            prototype=prototype,
+            byte_range=RangeByteRequest(group_start, group_end),
+        )
+        if group_bytes is None:
+            return {}
+
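+        # Chunk offsets are shard-relative; rebase each slice to the group buffer.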
+        shard_dict = {}
+        for chunk in group:
+            s = slice(chunk.byte_slice.start - group_start, chunk.byte_slice.stop - group_start)
+            shard_dict[chunk.coords] = group_bytes[s]
+        return shard_dict
+
     async def _encode_single(
         self,
         shard_array: NDBuffer,