@@ -90,9 +90,9 @@ async def get(
90
90
self , prototype : BufferPrototype , byte_range : ByteRequest | None = None
91
91
) -> Buffer | None :
92
92
assert byte_range is None , "byte_range is not supported within shards"
93
- assert prototype == default_buffer_prototype (), (
94
- f" prototype is not supported within shards currently. diff: { prototype } != { default_buffer_prototype ()} "
95
- )
93
+ assert (
94
+ prototype == default_buffer_prototype ()
95
+ ), f"prototype is not supported within shards currently. diff: { prototype } != { default_buffer_prototype () } "
96
96
return self .shard_dict .get (self .chunk_coords )
97
97
98
98
@@ -124,7 +124,9 @@ def chunks_per_shard(self) -> ChunkCoords:
124
124
def _localize_chunk (self , chunk_coords : ChunkCoords ) -> ChunkCoords :
125
125
return tuple (
126
126
chunk_i % shard_i
127
- for chunk_i , shard_i in zip (chunk_coords , self .offsets_and_lengths .shape , strict = False )
127
+ for chunk_i , shard_i in zip (
128
+ chunk_coords , self .offsets_and_lengths .shape , strict = False
129
+ )
128
130
)
129
131
130
132
def is_all_empty (self ) -> bool :
@@ -141,7 +143,9 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None:
141
143
else :
142
144
return (int (chunk_start ), int (chunk_start + chunk_len ))
143
145
144
- def set_chunk_slice (self , chunk_coords : ChunkCoords , chunk_slice : slice | None ) -> None :
146
+ def set_chunk_slice (
147
+ self , chunk_coords : ChunkCoords , chunk_slice : slice | None
148
+ ) -> None :
145
149
localized_chunk = self ._localize_chunk (chunk_coords )
146
150
if chunk_slice is None :
147
151
self .offsets_and_lengths [localized_chunk ] = (MAX_UINT_64 , MAX_UINT_64 )
@@ -163,7 +167,11 @@ def is_dense(self, chunk_byte_length: int) -> bool:
163
167
164
168
# Are all non-empty offsets unique?
165
169
if len (
166
- {offset for offset , _ in sorted_offsets_and_lengths if offset != MAX_UINT_64 }
170
+ {
171
+ offset
172
+ for offset , _ in sorted_offsets_and_lengths
173
+ if offset != MAX_UINT_64
174
+ }
167
175
) != len (sorted_offsets_and_lengths ):
168
176
return False
169
177
@@ -267,7 +275,9 @@ def __setitem__(self, chunk_coords: ChunkCoords, value: Buffer) -> None:
267
275
chunk_start = len (self .buf )
268
276
chunk_length = len (value )
269
277
self .buf += value
270
- self .index .set_chunk_slice (chunk_coords , slice (chunk_start , chunk_start + chunk_length ))
278
+ self .index .set_chunk_slice (
279
+ chunk_coords , slice (chunk_start , chunk_start + chunk_length )
280
+ )
271
281
272
282
def __delitem__ (self , chunk_coords : ChunkCoords ) -> None :
273
283
raise NotImplementedError
@@ -281,7 +291,9 @@ async def finalize(
281
291
if index_location == ShardingCodecIndexLocation .start :
282
292
empty_chunks_mask = self .index .offsets_and_lengths [..., 0 ] == MAX_UINT_64
283
293
self .index .offsets_and_lengths [~ empty_chunks_mask , 0 ] += len (index_bytes )
284
- index_bytes = await index_encoder (self .index ) # encode again with corrected offsets
294
+ index_bytes = await index_encoder (
295
+ self .index
296
+ ) # encode again with corrected offsets
285
297
out_buf = index_bytes + self .buf
286
298
else :
287
299
out_buf = self .buf + index_bytes
@@ -359,7 +371,8 @@ def __init__(
359
371
chunk_shape : ChunkCoordsLike ,
360
372
codecs : Iterable [Codec | dict [str , JSON ]] = (BytesCodec (),),
361
373
index_codecs : Iterable [Codec | dict [str , JSON ]] = (BytesCodec (), Crc32cCodec ()),
362
- index_location : ShardingCodecIndexLocation | str = ShardingCodecIndexLocation .end ,
374
+ index_location : ShardingCodecIndexLocation
375
+ | str = ShardingCodecIndexLocation .end ,
363
376
) -> None :
364
377
chunk_shape_parsed = parse_shapelike (chunk_shape )
365
378
codecs_parsed = parse_codecs (codecs )
@@ -389,7 +402,9 @@ def __setstate__(self, state: dict[str, Any]) -> None:
389
402
object .__setattr__ (self , "chunk_shape" , parse_shapelike (config ["chunk_shape" ]))
390
403
object .__setattr__ (self , "codecs" , parse_codecs (config ["codecs" ]))
391
404
object .__setattr__ (self , "index_codecs" , parse_codecs (config ["index_codecs" ]))
392
- object .__setattr__ (self , "index_location" , parse_index_location (config ["index_location" ]))
405
+ object .__setattr__ (
406
+ self , "index_location" , parse_index_location (config ["index_location" ])
407
+ )
393
408
394
409
# Use instance-local lru_cache to avoid memory leaks
395
410
# object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
@@ -418,7 +433,9 @@ def to_dict(self) -> dict[str, JSON]:
418
433
419
434
def evolve_from_array_spec (self , array_spec : ArraySpec ) -> Self :
420
435
shard_spec = self ._get_chunk_spec (array_spec )
421
- evolved_codecs = tuple (c .evolve_from_array_spec (array_spec = shard_spec ) for c in self .codecs )
436
+ evolved_codecs = tuple (
437
+ c .evolve_from_array_spec (array_spec = shard_spec ) for c in self .codecs
438
+ )
422
439
if evolved_codecs != self .codecs :
423
440
return replace (self , codecs = evolved_codecs )
424
441
return self
@@ -469,7 +486,7 @@ async def _decode_single(
469
486
shape = shard_shape ,
470
487
dtype = shard_spec .dtype .to_native_dtype (),
471
488
order = shard_spec .order ,
472
- fill_value = 0 ,
489
+ fill_value = shard_spec . fill_value ,
473
490
)
474
491
shard_dict = await _ShardReader .from_bytes (shard_bytes , self , chunks_per_shard )
475
492
@@ -516,7 +533,7 @@ async def _decode_partial_single(
516
533
shape = indexer .shape ,
517
534
dtype = shard_spec .dtype .to_native_dtype (),
518
535
order = shard_spec .order ,
519
- fill_value = 0 ,
536
+ fill_value = shard_spec . fill_value ,
520
537
)
521
538
522
539
indexed_chunks = list (indexer )
@@ -593,7 +610,9 @@ async def _encode_single(
593
610
shard_array ,
594
611
)
595
612
596
- return await shard_builder .finalize (self .index_location , self ._encode_shard_index )
613
+ return await shard_builder .finalize (
614
+ self .index_location , self ._encode_shard_index
615
+ )
597
616
598
617
async def _encode_partial_single (
599
618
self ,
@@ -653,7 +672,8 @@ def _is_total_shard(
653
672
self , all_chunk_coords : set [ChunkCoords ], chunks_per_shard : ChunkCoords
654
673
) -> bool :
655
674
return len (all_chunk_coords ) == product (chunks_per_shard ) and all (
656
- chunk_coords in all_chunk_coords for chunk_coords in c_order_iter (chunks_per_shard )
675
+ chunk_coords in all_chunk_coords
676
+ for chunk_coords in c_order_iter (chunks_per_shard )
657
677
)
658
678
659
679
async def _decode_shard_index (
@@ -679,7 +699,9 @@ async def _encode_shard_index(self, index: _ShardIndex) -> Buffer:
679
699
.encode (
680
700
[
681
701
(
682
- get_ndbuffer_class ().from_numpy_array (index .offsets_and_lengths ),
702
+ get_ndbuffer_class ().from_numpy_array (
703
+ index .offsets_and_lengths
704
+ ),
683
705
self ._get_index_chunk_spec (index .chunks_per_shard ),
684
706
)
685
707
],
@@ -790,8 +812,8 @@ async def _load_partial_shard_maybe(
790
812
# Drop chunks where index lookup fails
791
813
if (chunk_byte_slice := shard_index .get_chunk_slice (chunk_coords ))
792
814
]
793
- if len (chunks ) == 0 :
794
- return {}
815
+ if len (chunks ) < len ( all_chunk_coords ) :
816
+ return None
795
817
796
818
groups = self ._coalesce_chunks (chunks )
797
819
@@ -803,6 +825,8 @@ async def _load_partial_shard_maybe(
803
825
804
826
shard_dict : ShardMutableMapping = {}
805
827
for d in shard_dicts :
828
+ if d is None :
829
+ return None
806
830
shard_dict .update (d )
807
831
808
832
return shard_dict
@@ -830,7 +854,9 @@ def _coalesce_chunks(
830
854
831
855
for chunk in sorted_chunks [1 :]:
832
856
gap_to_chunk = chunk .byte_slice .start - current_group [- 1 ].byte_slice .stop
833
- size_if_coalesced = chunk .byte_slice .stop - current_group [0 ].byte_slice .start
857
+ size_if_coalesced = (
858
+ chunk .byte_slice .stop - current_group [0 ].byte_slice .start
859
+ )
834
860
if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes :
835
861
current_group .append (chunk )
836
862
else :
@@ -846,7 +872,7 @@ async def _get_group_bytes(
846
872
group : list [_ChunkCoordsByteSlice ],
847
873
byte_getter : ByteGetter ,
848
874
prototype : BufferPrototype ,
849
- ) -> ShardMapping :
875
+ ) -> ShardMapping | None :
850
876
"""
851
877
Reads a possibly coalesced group of one or more chunks from a shard.
852
878
Returns a mapping of chunk coordinates to bytes.
@@ -860,7 +886,7 @@ async def _get_group_bytes(
860
886
byte_range = RangeByteRequest (group_start , group_end ),
861
887
)
862
888
if group_bytes is None :
863
- return {}
889
+ return None
864
890
865
891
# Extract the bytes corresponding to each chunk in group from group_bytes.
866
892
shard_dict = {}
@@ -873,7 +899,9 @@ async def _get_group_bytes(
873
899
874
900
return shard_dict
875
901
876
- def compute_encoded_size (self , input_byte_length : int , shard_spec : ArraySpec ) -> int :
902
+ def compute_encoded_size (
903
+ self , input_byte_length : int , shard_spec : ArraySpec
904
+ ) -> int :
877
905
chunks_per_shard = self ._get_chunks_per_shard (shard_spec )
878
906
return input_byte_length + self ._shard_index_size (chunks_per_shard )
879
907
0 commit comments