From 1359aec774387c58c0a715c2ea0909fbdb1286d8 Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Wed, 30 Apr 2025 21:45:37 -0400 Subject: [PATCH 01/12] Add performance test of partial shard reads --- tests/test_codecs/test_sharding.py | 65 ++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 403fd80e81..6efc3a251c 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -197,6 +197,71 @@ def test_sharding_partial_read( assert np.all(read_data == 1) +@pytest.mark.slow_hypothesis +@pytest.mark.parametrize("store", ["local"], indirect=["store"]) +def test_partial_shard_read_performance(store: Store) -> None: + import asyncio + import json + from functools import partial + from itertools import product + from timeit import timeit + from unittest.mock import AsyncMock + + # The whole test array is a single shard to keep runtime manageable while + # using a realistic shard size (256 MiB uncompressed, ~115 MiB compressed). + # In practice, the array is likely to be much larger with many shards of this + # rough order of magnitude. There are 512 chunks per shard in this example. + array_shape = (512, 512, 512) + shard_shape = (512, 512, 512) # 256 MiB uncompressed unit16s + chunk_shape = (64, 64, 64) # 512 KiB uncompressed unit16s + dtype = np.uint16 + + a = zarr.create_array( + StorePath(store), + shape=array_shape, + chunks=chunk_shape, + shards=shard_shape, + compressors=BloscCodec(cname="zstd"), + dtype=dtype, + fill_value=np.iinfo(dtype).max, + ) + # Narrow range of values lets zstd compress to about 1/2 of uncompressed size + a[:] = np.random.default_rng(123).integers(low=0, high=50, size=array_shape, dtype=dtype) + + num_calls = 20 + experiments = [] + for concurrency, get_latency, statement in product( + [1, 10, 100], [0.0, 0.01], ["a[0, :, :]", "a[:, 0, :]", "a[:, :, 0]"] + ): + zarr.config.set({"async.concurrency": concurrency}) + + async def get_with_latency(*args: Any, get_latency: float, **kwargs: Any) -> Any: + await asyncio.sleep(get_latency) + return await store.get(*args, **kwargs) + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + store_mock.get.side_effect = partial(get_with_latency, get_latency=get_latency) + + a = zarr.open_array(StorePath(store_mock)) + + store_mock.reset_mock() + + # Each timeit call accesses a 512x512 slice covering 64 chunks + time = timeit(statement, number=num_calls, globals={"a": a}) / num_calls + experiments.append( + { + "concurrency": concurrency, + "statement": statement, + "get_latency": get_latency, + "time": time, + "store_get_calls": store_mock.get.call_count, + } + ) + + with open("zarr-python-partial-shard-read-performance-no-coalesce.json", "w") as f: + json.dump(experiments, f) + + @pytest.mark.parametrize( "array_fixture", [ From c7269944fde1bd9675771096992043fd0c95b715 Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Sat, 19 Apr 2025 00:57:46 -0400 Subject: [PATCH 02/12] WIP Consolidate reads of multiple chunks in the same shard Add test and make max gap and max coalesce size config options Code clarity and comments Test that chunk request coalescing reduces calls to store Profile a few values for coalesce_max_gap Update [doc]tests to include new sharding.read.* values document sharded read config options in user-guide/config.rst tweak logic: start new coalesced group if coalescing would exceed `coalesce_max_bytes` previous logic only started a new group if existing group was size already exceeded coalesce_max_bytes. set `mypy_path = "src"` to help pre-commit mypy find imported classes Reorder methods in sharding.py, add docstring + commenting wording docs fix docstring clarification trigger precommit on all python files changed in this pull request trying to get the ruff format that's happening locally during pre-commit to match the pre-commit run that is failing on CI. revert trigger for pre-commit ruff format --- docs/user-guide/config.rst | 8 +- pyproject.toml | 1 + src/zarr/codecs/sharding.py | 169 ++++++++++++++++++++++++----- src/zarr/core/config.py | 6 + tests/test_codecs/test_sharding.py | 88 ++++++++++++++- tests/test_config.py | 8 ++ 6 files changed, 246 insertions(+), 34 deletions(-) diff --git a/docs/user-guide/config.rst b/docs/user-guide/config.rst index 5a9d26f2b9..06b1e79473 100644 --- a/docs/user-guide/config.rst +++ b/docs/user-guide/config.rst @@ -33,6 +33,10 @@ Configuration options include the following: - Async and threading options, e.g. ``async.concurrency`` and ``threading.max_workers`` - Selections of implementations of codecs, codec pipelines and buffers - Enabling GPU support with ``zarr.config.enable_gpu()``. See :ref:`user-guide-gpu` for more. +- Tuning reads from sharded zarrs. When reading less than a complete shard, reads of nearby chunks + within the same shard will be combined into a single request if they are less than + ``sharding.read.coalesce_max_gap_bytes`` apart and the combined request size is less than + ``sharding.read.coalesce_max_bytes``. For selecting custom implementations of codecs, pipelines, buffers and ndbuffers, first register the implementations in the registry and then select them in the config. @@ -79,4 +83,6 @@ This is the current default configuration:: 'default_zarr_format': 3, 'json_indent': 2, 'ndbuffer': 'zarr.buffer.cpu.NDBuffer', - 'threading': {'max_workers': None}} + 'sharding': {'read': {'coalesce_max_bytes': 104857600, + 'coalesce_max_gap_bytes': 1048576}}, + 'threading': {'max_workers': None}} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a48a5eea25..cd5de2115e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -354,6 +354,7 @@ ignore = [ [tool.mypy] python_version = "3.11" ignore_missing_imports = true +mypy_path = "src" namespace_packages = false strict = true diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index cd8676b4d1..e7499f726f 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -38,11 +38,13 @@ from zarr.core.common import ( ChunkCoords, ChunkCoordsLike, + concurrent_map, parse_enum, parse_named_configuration, parse_shapelike, product, ) +from zarr.core.config import config from zarr.core.dtype.npy.int import UInt64 from zarr.core.indexing import ( BasicIndexer, @@ -198,7 +200,9 @@ async def from_bytes( @classmethod def create_empty( - cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None + cls, + chunks_per_shard: ChunkCoords, + buffer_prototype: BufferPrototype | None = None, ) -> _ShardReader: if buffer_prototype is None: buffer_prototype = default_buffer_prototype() @@ -248,7 +252,9 @@ def merge_with_morton_order( @classmethod def create_empty( - cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None + cls, + chunks_per_shard: ChunkCoords, + buffer_prototype: BufferPrototype | None = None, ) -> _ShardBuilder: if buffer_prototype is None: buffer_prototype = default_buffer_prototype() @@ -329,9 +335,18 @@ async def finalize( return await shard_builder.finalize(index_location, index_encoder) +class _ChunkCoordsByteSlice(NamedTuple): + """Holds a chunk's coordinates and its byte range in a serialized shard.""" + + coords: ChunkCoords + byte_slice: slice + + @dataclass(frozen=True) class ShardingCodec( - ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin + ArrayBytesCodec, + ArrayBytesCodecPartialDecodeMixin, + ArrayBytesCodecPartialEncodeMixin, ): chunk_shape: ChunkCoords codecs: tuple[Codec, ...] @@ -508,32 +523,21 @@ async def _decode_partial_single( all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks} # reading bytes of all requested chunks - shard_dict: ShardMapping = {} + shard_dict_maybe: ShardMapping | None = {} if self._is_total_shard(all_chunk_coords, chunks_per_shard): # read entire shard shard_dict_maybe = await self._load_full_shard_maybe( - byte_getter=byte_getter, - prototype=chunk_spec.prototype, - chunks_per_shard=chunks_per_shard, + byte_getter, chunk_spec.prototype, chunks_per_shard ) - if shard_dict_maybe is None: - return None - shard_dict = shard_dict_maybe else: # read some chunks within the shard - shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) - if shard_index is None: - return None - shard_dict = {} - for chunk_coords in all_chunk_coords: - chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords) - if chunk_byte_slice: - chunk_bytes = await byte_getter.get( - prototype=chunk_spec.prototype, - byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]), - ) - if chunk_bytes: - shard_dict[chunk_coords] = chunk_bytes + shard_dict_maybe = await self._load_partial_shard_maybe( + byte_getter, chunk_spec.prototype, chunks_per_shard, all_chunk_coords + ) + + if shard_dict_maybe is None: + return None + shard_dict = shard_dict_maybe # decoding chunks and writing them into the output buffer await self.codec_pipeline.read( @@ -615,7 +619,9 @@ async def _encode_partial_single( indexer = list( get_indexer( - selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape) + selection, + shape=shard_shape, + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), ) ) @@ -689,7 +695,8 @@ def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int: get_pipeline_class() .from_codecs(self.index_codecs) .compute_encoded_size( - 16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard) + 16 * product(chunks_per_shard), + self._get_index_chunk_spec(chunks_per_shard), ) ) @@ -734,7 +741,8 @@ async def _load_shard_index_maybe( ) else: index_bytes = await byte_getter.get( - prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size) + prototype=numpy_buffer_prototype(), + byte_range=SuffixByteRequest(shard_index_size), ) if index_bytes is not None: return await self._decode_shard_index(index_bytes, chunks_per_shard) @@ -748,7 +756,10 @@ async def _load_shard_index( ) or _ShardIndex.create_empty(chunks_per_shard) async def _load_full_shard_maybe( - self, byte_getter: ByteGetter, prototype: BufferPrototype, chunks_per_shard: ChunkCoords + self, + byte_getter: ByteGetter, + prototype: BufferPrototype, + chunks_per_shard: ChunkCoords, ) -> _ShardReader | None: shard_bytes = await byte_getter.get(prototype=prototype) @@ -758,6 +769,110 @@ async def _load_full_shard_maybe( else None ) + async def _load_partial_shard_maybe( + self, + byte_getter: ByteGetter, + prototype: BufferPrototype, + chunks_per_shard: ChunkCoords, + all_chunk_coords: set[ChunkCoords], + ) -> ShardMapping | None: + """ + Read chunks from `byte_getter` for the case where the read is less than a full shard. + Returns a mapping of chunk coordinates to bytes. + """ + shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) + if shard_index is None: + return None + + chunks = [ + _ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice)) + for chunk_coords in all_chunk_coords + # Drop chunks where index lookup fails + if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords)) + ] + if len(chunks) == 0: + return {} + + groups = self._coalesce_chunks(chunks) + + shard_dicts = await concurrent_map( + [(group, byte_getter, prototype) for group in groups], + self._get_group_bytes, + config.get("async.concurrency"), + ) + + shard_dict: ShardMutableMapping = {} + for d in shard_dicts: + shard_dict.update(d) + + return shard_dict + + def _coalesce_chunks( + self, + chunks: list[_ChunkCoordsByteSlice], + ) -> list[list[_ChunkCoordsByteSlice]]: + """ + Combine chunks from a single shard into groups that should be read together + in a single request. + + Respects the following configuration options: + - `sharding.read.coalesce_max_gap_bytes`: The maximum gap between + chunks to coalesce into a single group. + - `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group. + """ + max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes") + coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes") + + sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start) + + groups = [] + current_group = [sorted_chunks[0]] + + for chunk in sorted_chunks[1:]: + gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop + size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start + if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes: + current_group.append(chunk) + else: + groups.append(current_group) + current_group = [chunk] + + groups.append(current_group) + + return groups + + async def _get_group_bytes( + self, + group: list[_ChunkCoordsByteSlice], + byte_getter: ByteGetter, + prototype: BufferPrototype, + ) -> ShardMapping: + """ + Reads a possibly coalesced group of one or more chunks from a shard. + Returns a mapping of chunk coordinates to bytes. + """ + group_start = group[0].byte_slice.start + group_end = group[-1].byte_slice.stop + + # A single call to retrieve the bytes for the entire group. + group_bytes = await byte_getter.get( + prototype=prototype, + byte_range=RangeByteRequest(group_start, group_end), + ) + if group_bytes is None: + return {} + + # Extract the bytes corresponding to each chunk in group from group_bytes. + shard_dict = {} + for chunk in group: + chunk_slice = slice( + chunk.byte_slice.start - group_start, + chunk.byte_slice.stop - group_start, + ) + shard_dict[chunk.coords] = group_bytes[chunk_slice] + + return shard_dict + def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int: chunks_per_shard = self._get_chunks_per_shard(shard_spec) return input_byte_length + self._shard_index_size(chunks_per_shard) diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index 05d048ef74..993eaa919d 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -111,6 +111,12 @@ def enable_gpu(self) -> ConfigSet: }, "async": {"concurrency": 10, "timeout": None}, "threading": {"max_workers": None}, + "sharding": { + "read": { + "coalesce_max_bytes": 100 * 2**20, # 100MiB + "coalesce_max_gap_bytes": 2**20, # 1MiB + } + }, "json_indent": 2, "codec_pipeline": { "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 6efc3a251c..cb14ee97dc 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -1,5 +1,6 @@ import pickle from typing import Any +from unittest.mock import AsyncMock import numpy as np import numpy.typing as npt @@ -9,7 +10,7 @@ import zarr.api import zarr.api.asynchronous from zarr import Array -from zarr.abc.store import Store +from zarr.abc.store import RangeByteRequest, Store, SuffixByteRequest from zarr.codecs import ( BloscCodec, ShardingCodec, @@ -197,6 +198,7 @@ def test_sharding_partial_read( assert np.all(read_data == 1) +@pytest.mark.skip("This is profiling rather than a test") @pytest.mark.slow_hypothesis @pytest.mark.parametrize("store", ["local"], indirect=["store"]) def test_partial_shard_read_performance(store: Store) -> None: @@ -230,10 +232,18 @@ def test_partial_shard_read_performance(store: Store) -> None: num_calls = 20 experiments = [] - for concurrency, get_latency, statement in product( - [1, 10, 100], [0.0, 0.01], ["a[0, :, :]", "a[:, 0, :]", "a[:, :, 0]"] + for concurrency, get_latency, coalesce_max_gap, statement in product( + [1, 10, 100], + [0.0, 0.01], + [-1, 2**20, 10 * 2**20], + ["a[0, :, :]", "a[:, 0, :]", "a[:, :, 0]"], ): - zarr.config.set({"async.concurrency": concurrency}) + zarr.config.set( + { + "async.concurrency": concurrency, + "sharding.read.coalesce_max_gap_bytes": coalesce_max_gap, + } + ) async def get_with_latency(*args: Any, get_latency: float, **kwargs: Any) -> Any: await asyncio.sleep(get_latency) @@ -251,17 +261,83 @@ async def get_with_latency(*args: Any, get_latency: float, **kwargs: Any) -> Any experiments.append( { "concurrency": concurrency, - "statement": statement, + "coalesce_max_gap": coalesce_max_gap, "get_latency": get_latency, + "statement": statement, "time": time, "store_get_calls": store_mock.get.call_count, } ) - with open("zarr-python-partial-shard-read-performance-no-coalesce.json", "w") as f: + with open("zarr-python-partial-shard-read-performance-with-coalesce.json", "w") as f: json.dump(experiments, f) +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +@pytest.mark.parametrize("coalesce_reads", [True, False]) +def test_sharding_multiple_chunks_partial_shard_read( + store: Store, index_location: ShardingCodecIndexLocation, coalesce_reads: bool +) -> None: + array_shape = (16, 64) + shard_shape = (8, 32) + chunk_shape = (2, 4) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + + if coalesce_reads: + # 1MiB, enough to coalesce all chunks within a shard in this example + zarr.config.set({"sharding.read.coalesce_max_gap_bytes": 2**20}) + else: + zarr.config.set({"sharding.read.coalesce_max_gap_bytes": -1}) # disable coalescing + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=1, + ) + a[:] = data + + store_mock.reset_mock() # ignore store calls during array creation + + # Reads 3 (2 full, 1 partial) chunks each from 2 shards (a subset of both shards) + # for a total of 6 chunks accessed + assert np.allclose(a[0, 22:42], np.arange(22, 42, dtype="float32")) + + if coalesce_reads: + # 2 shard index requests + 2 coalesced chunk data byte ranges (one for each shard) + assert store_mock.get.call_count == 4 + else: + # 2 shard index requests + 6 chunks + assert store_mock.get.call_count == 8 + + for method, args, kwargs in store_mock.method_calls: + assert method == "get" + assert args[0].startswith("c/") # get from a chunk + assert isinstance(kwargs["byte_range"], (SuffixByteRequest, RangeByteRequest)) + + store_mock.reset_mock() + + # Reads 4 chunks from both shards along dimension 0 for a total of 8 chunks accessed + assert np.allclose(a[:, 0], np.arange(0, data.size, array_shape[1], dtype="float32")) + + if coalesce_reads: + # 2 shard index requests + 2 coalesced chunk data byte ranges (one for each shard) + assert store_mock.get.call_count == 4 + else: + # 2 shard index requests + 8 chunks + assert store_mock.get.call_count == 10 + + for method, args, kwargs in store_mock.method_calls: + assert method == "get" + assert args[0].startswith("c/") # get from a chunk + assert isinstance(kwargs["byte_range"], (SuffixByteRequest, RangeByteRequest)) + + @pytest.mark.parametrize( "array_fixture", [ diff --git a/tests/test_config.py b/tests/test_config.py index e267601272..0c941ee62e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -103,6 +103,12 @@ def test_config_defaults_set() -> None: }, "buffer": "zarr.buffer.cpu.Buffer", "ndbuffer": "zarr.buffer.cpu.NDBuffer", + "sharding": { + "read": { + "coalesce_max_bytes": 100 * 2**20, # 100 MiB + "coalesce_max_gap_bytes": 2**20, # 1 MiB + } + }, } ] ) @@ -111,6 +117,8 @@ def test_config_defaults_set() -> None: assert config.get("async.timeout") is None assert config.get("codec_pipeline.batch_size") == 1 assert config.get("json_indent") == 2 + assert config.get("sharding.read.coalesce_max_bytes") == 100 * 2**20 # 100 MiB + assert config.get("sharding.read.coalesce_max_gap_bytes") == 2**20 # 1 MiB @pytest.mark.parametrize( From 44d9ce4d8c655588ee12beb79085ac809fc7cfcb Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Mon, 2 Jun 2025 14:16:05 -0400 Subject: [PATCH 03/12] Add changes/3004.feature.rst --- changes/3004.feature.rst | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 changes/3004.feature.rst diff --git a/changes/3004.feature.rst b/changes/3004.feature.rst new file mode 100644 index 0000000000..b15a5ec943 --- /dev/null +++ b/changes/3004.feature.rst @@ -0,0 +1,3 @@ +Optimizes reading more than one, but not all, chunks from a shard. Chunks are now read in parallel +and reads of nearby chunks within the same shard are combined to reduce the number of calls to the store. +See :ref:`user-guide-config` for more details. \ No newline at end of file From 009ce6a1abfd9d1d6b32bd51c2e7dde4ecbaa592 Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Mon, 2 Jun 2025 22:10:56 -0400 Subject: [PATCH 04/12] Consistently return None on failure and test partial shard read failure modes Use range of integers as out_selection not slice in CoordinateIndexer To fix issue when using vindex with repeated indexes in indexer test: improve formatting and add debugging breakpoint in array property tests test: disable hypothesis deadline for test_array_roundtrip to prevent timeout fix: initialize decode buffers with shard_spec.fill_value instead of 0 to fix partial shard holes style: reformat code for improved readability and consistency in sharding.py fix: revert incorrect RangeByteRequest length fix in sharding byte retrieval --- src/zarr/codecs/sharding.py | 72 +++++++++++----- src/zarr/core/indexing.py | 2 +- tests/test_codecs/test_sharding.py | 128 +++++++++++++++++++++++++++-- tests/test_properties.py | 49 +++++++---- 4 files changed, 206 insertions(+), 45 deletions(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index e7499f726f..ec4fe476f6 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -90,9 +90,9 @@ async def get( self, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: assert byte_range is None, "byte_range is not supported within shards" - assert prototype == default_buffer_prototype(), ( - f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}" - ) + assert ( + prototype == default_buffer_prototype() + ), f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}" return self.shard_dict.get(self.chunk_coords) @@ -124,7 +124,9 @@ def chunks_per_shard(self) -> ChunkCoords: def _localize_chunk(self, chunk_coords: ChunkCoords) -> ChunkCoords: return tuple( chunk_i % shard_i - for chunk_i, shard_i in zip(chunk_coords, self.offsets_and_lengths.shape, strict=False) + for chunk_i, shard_i in zip( + chunk_coords, self.offsets_and_lengths.shape, strict=False + ) ) def is_all_empty(self) -> bool: @@ -141,7 +143,9 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None: else: return (int(chunk_start), int(chunk_start + chunk_len)) - def set_chunk_slice(self, chunk_coords: ChunkCoords, chunk_slice: slice | None) -> None: + def set_chunk_slice( + self, chunk_coords: ChunkCoords, chunk_slice: slice | None + ) -> None: localized_chunk = self._localize_chunk(chunk_coords) if chunk_slice is None: self.offsets_and_lengths[localized_chunk] = (MAX_UINT_64, MAX_UINT_64) @@ -163,7 +167,11 @@ def is_dense(self, chunk_byte_length: int) -> bool: # Are all non-empty offsets unique? if len( - {offset for offset, _ in sorted_offsets_and_lengths if offset != MAX_UINT_64} + { + offset + for offset, _ in sorted_offsets_and_lengths + if offset != MAX_UINT_64 + } ) != len(sorted_offsets_and_lengths): return False @@ -267,7 +275,9 @@ def __setitem__(self, chunk_coords: ChunkCoords, value: Buffer) -> None: chunk_start = len(self.buf) chunk_length = len(value) self.buf += value - self.index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length)) + self.index.set_chunk_slice( + chunk_coords, slice(chunk_start, chunk_start + chunk_length) + ) def __delitem__(self, chunk_coords: ChunkCoords) -> None: raise NotImplementedError @@ -281,7 +291,9 @@ async def finalize( if index_location == ShardingCodecIndexLocation.start: empty_chunks_mask = self.index.offsets_and_lengths[..., 0] == MAX_UINT_64 self.index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes) - index_bytes = await index_encoder(self.index) # encode again with corrected offsets + index_bytes = await index_encoder( + self.index + ) # encode again with corrected offsets out_buf = index_bytes + self.buf else: out_buf = self.buf + index_bytes @@ -359,7 +371,8 @@ def __init__( chunk_shape: ChunkCoordsLike, codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(),), index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()), - index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end, + index_location: ShardingCodecIndexLocation + | str = ShardingCodecIndexLocation.end, ) -> None: chunk_shape_parsed = parse_shapelike(chunk_shape) codecs_parsed = parse_codecs(codecs) @@ -389,7 +402,9 @@ def __setstate__(self, state: dict[str, Any]) -> None: object.__setattr__(self, "chunk_shape", parse_shapelike(config["chunk_shape"])) object.__setattr__(self, "codecs", parse_codecs(config["codecs"])) object.__setattr__(self, "index_codecs", parse_codecs(config["index_codecs"])) - object.__setattr__(self, "index_location", parse_index_location(config["index_location"])) + object.__setattr__( + self, "index_location", parse_index_location(config["index_location"]) + ) # Use instance-local lru_cache to avoid memory leaks # object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec)) @@ -418,7 +433,9 @@ def to_dict(self) -> dict[str, JSON]: def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: shard_spec = self._get_chunk_spec(array_spec) - evolved_codecs = tuple(c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs) + evolved_codecs = tuple( + c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs + ) if evolved_codecs != self.codecs: return replace(self, codecs=evolved_codecs) return self @@ -469,7 +486,7 @@ async def _decode_single( shape=shard_shape, dtype=shard_spec.dtype.to_native_dtype(), order=shard_spec.order, - fill_value=0, + fill_value=shard_spec.fill_value, ) shard_dict = await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard) @@ -516,7 +533,7 @@ async def _decode_partial_single( shape=indexer.shape, dtype=shard_spec.dtype.to_native_dtype(), order=shard_spec.order, - fill_value=0, + fill_value=shard_spec.fill_value, ) indexed_chunks = list(indexer) @@ -593,7 +610,9 @@ async def _encode_single( shard_array, ) - return await shard_builder.finalize(self.index_location, self._encode_shard_index) + return await shard_builder.finalize( + self.index_location, self._encode_shard_index + ) async def _encode_partial_single( self, @@ -653,7 +672,8 @@ def _is_total_shard( self, all_chunk_coords: set[ChunkCoords], chunks_per_shard: ChunkCoords ) -> bool: return len(all_chunk_coords) == product(chunks_per_shard) and all( - chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard) + chunk_coords in all_chunk_coords + for chunk_coords in c_order_iter(chunks_per_shard) ) async def _decode_shard_index( @@ -679,7 +699,9 @@ async def _encode_shard_index(self, index: _ShardIndex) -> Buffer: .encode( [ ( - get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths), + get_ndbuffer_class().from_numpy_array( + index.offsets_and_lengths + ), self._get_index_chunk_spec(index.chunks_per_shard), ) ], @@ -790,8 +812,8 @@ async def _load_partial_shard_maybe( # Drop chunks where index lookup fails if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords)) ] - if len(chunks) == 0: - return {} + if len(chunks) < len(all_chunk_coords): + return None groups = self._coalesce_chunks(chunks) @@ -803,6 +825,8 @@ async def _load_partial_shard_maybe( shard_dict: ShardMutableMapping = {} for d in shard_dicts: + if d is None: + return None shard_dict.update(d) return shard_dict @@ -830,7 +854,9 @@ def _coalesce_chunks( for chunk in sorted_chunks[1:]: gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop - size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start + size_if_coalesced = ( + chunk.byte_slice.stop - current_group[0].byte_slice.start + ) if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes: current_group.append(chunk) else: @@ -846,7 +872,7 @@ async def _get_group_bytes( group: list[_ChunkCoordsByteSlice], byte_getter: ByteGetter, prototype: BufferPrototype, - ) -> ShardMapping: + ) -> ShardMapping | None: """ Reads a possibly coalesced group of one or more chunks from a shard. Returns a mapping of chunk coordinates to bytes. @@ -860,7 +886,7 @@ async def _get_group_bytes( byte_range=RangeByteRequest(group_start, group_end), ) if group_bytes is None: - return {} + return None # Extract the bytes corresponding to each chunk in group from group_bytes. shard_dict = {} @@ -873,7 +899,9 @@ async def _get_group_bytes( return shard_dict - def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int: + def compute_encoded_size( + self, input_byte_length: int, shard_spec: ArraySpec + ) -> int: chunks_per_shard = self._get_chunks_per_shard(shard_spec) return input_byte_length + self._shard_index_size(chunks_per_shard) diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index c11889f7f4..0e0bb664d8 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -1193,7 +1193,7 @@ def __iter__(self) -> Iterator[ChunkProjection]: stop = self.chunk_nitems_cumsum[chunk_rix] out_selection: slice | npt.NDArray[np.intp] if self.sel_sort is None: - out_selection = slice(start, stop) + out_selection = np.arange(start, stop) else: out_selection = self.sel_sort[start:stop] diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index cb14ee97dc..dbe64a32d5 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -111,7 +111,9 @@ def test_sharding_scalar( indirect=["array_fixture"], ) def test_sharding_partial( - store: Store, array_fixture: npt.NDArray[Any], index_location: ShardingCodecIndexLocation + store: Store, + array_fixture: npt.NDArray[Any], + index_location: ShardingCodecIndexLocation, ) -> None: data = array_fixture spath = StorePath(store) @@ -147,7 +149,9 @@ def test_sharding_partial( indirect=["array_fixture"], ) def test_sharding_partial_readwrite( - store: Store, array_fixture: npt.NDArray[Any], index_location: ShardingCodecIndexLocation + store: Store, + array_fixture: npt.NDArray[Any], + index_location: ShardingCodecIndexLocation, ) -> None: data = array_fixture spath = StorePath(store) @@ -179,7 +183,9 @@ def test_sharding_partial_readwrite( @pytest.mark.parametrize("index_location", ["start", "end"]) @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) def test_sharding_partial_read( - store: Store, array_fixture: npt.NDArray[Any], index_location: ShardingCodecIndexLocation + store: Store, + array_fixture: npt.NDArray[Any], + index_location: ShardingCodecIndexLocation, ) -> None: data = array_fixture spath = StorePath(store) @@ -338,6 +344,114 @@ def test_sharding_multiple_chunks_partial_shard_read( assert isinstance(kwargs["byte_range"], (SuffixByteRequest, RangeByteRequest)) +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_partial_shard_read__index_load_fails( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """Test fill value is returned when the call to the store to load the bytes of the shard's chunk index fails.""" + array_shape = (16,) + shard_shape = (16,) + chunk_shape = (8,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + fill_value = -999 + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + # loading the index is the first call to .get() so returning None will simulate an index load failure + store_mock.get.return_value = None + + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=fill_value, + ) + a[:] = data + + # Read from one of two chunks in a shard to test the partial shard read path + assert a[0] == fill_value + assert a[0] != data[0] + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_partial_shard_read__index_chunk_slice_fails( + store: Store, + index_location: ShardingCodecIndexLocation, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test fill value is returned when looking up a chunk's byte slice within a shard fails.""" + array_shape = (16,) + shard_shape = (16,) + chunk_shape = (8,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + fill_value = -999 + + monkeypatch.setattr( + "zarr.codecs.sharding._ShardIndex.get_chunk_slice", + lambda self, chunk_coords: None, + ) + + a = zarr.create_array( + StorePath(store), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=fill_value, + ) + a[:] = data + + # Read from one of two chunks in a shard to test the partial shard read path + assert a[0] == fill_value + assert a[0] != data[0] + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_partial_shard_read__chunk_load_fails( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """Test fill value is returned when the call to the store to load a chunk's bytes fails.""" + array_shape = (16,) + shard_shape = (16,) + chunk_shape = (8,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + fill_value = -999 + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=fill_value, + ) + a[:] = data + + # Set up store mock after array creation to only modify calls during array indexing + # Succeed on first call (index load), fail on subsequent calls (chunk loads) + async def first_success_then_fail(*args: Any, **kwargs: Any) -> Any: + if store_mock.get.call_count == 1: + return await store.get(*args, **kwargs) + else: + return None + + store_mock.get.reset_mock() + store_mock.get.side_effect = first_success_then_fail + + # Read from one of two chunks in a shard to test the partial shard read path + assert a[0] == fill_value + assert a[0] != data[0] + + @pytest.mark.parametrize( "array_fixture", [ @@ -348,7 +462,9 @@ def test_sharding_multiple_chunks_partial_shard_read( @pytest.mark.parametrize("index_location", ["start", "end"]) @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) def test_sharding_partial_overwrite( - store: Store, array_fixture: npt.NDArray[Any], index_location: ShardingCodecIndexLocation + store: Store, + array_fixture: npt.NDArray[Any], + index_location: ShardingCodecIndexLocation, ) -> None: data = array_fixture[:10, :10, :10] spath = StorePath(store) @@ -578,7 +694,9 @@ async def test_sharding_with_empty_inner_chunk( ) @pytest.mark.parametrize("chunks_per_shard", [(5, 2), (2, 5), (5, 5)]) async def test_sharding_with_chunks_per_shard( - store: Store, index_location: ShardingCodecIndexLocation, chunks_per_shard: tuple[int] + store: Store, + index_location: ShardingCodecIndexLocation, + chunks_per_shard: tuple[int], ) -> None: chunk_shape = (2, 1) shape = tuple(x * y for x, y in zip(chunks_per_shard, chunk_shape, strict=False)) diff --git a/tests/test_properties.py b/tests/test_properties.py index b8d50ef0b1..de302c56b0 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -79,8 +79,14 @@ def deep_equal(a: Any, b: Any) -> bool: @given(data=st.data(), zarr_format=zarr_formats) def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None: nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format))) - zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format))) - assert_array_equal(nparray, zarray[:]) + zarray = data.draw( + arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)) + ) + try: + assert_array_equal(nparray, zarray[:]) + except Exception as e: + breakpoint() + raise e @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") @@ -92,12 +98,20 @@ def test_array_creates_implicit_groups(array): parent = "/".join(ancestry[: i + 1]) if array.metadata.zarr_format == 2: assert ( - sync(array.store.get(f"{parent}/.zgroup", prototype=default_buffer_prototype())) + sync( + array.store.get( + f"{parent}/.zgroup", prototype=default_buffer_prototype() + ) + ) is not None ) elif array.metadata.zarr_format == 3: assert ( - sync(array.store.get(f"{parent}/zarr.json", prototype=default_buffer_prototype())) + sync( + array.store.get( + f"{parent}/zarr.json", prototype=default_buffer_prototype() + ) + ) is not None ) @@ -115,7 +129,9 @@ def test_basic_indexing(data: st.DataObject) -> None: actual = zarray[indexer] assert_array_equal(nparray[indexer], actual) - new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype)) + new_data = data.draw( + numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype) + ) zarray[indexer] = new_data nparray[indexer] = new_data assert_array_equal(nparray, zarray[:]) @@ -137,7 +153,9 @@ def test_oindex(data: st.DataObject) -> None: if isinstance(idxr, np.ndarray) and idxr.size != np.unique(idxr).size: # behaviour of setitem with repeated indices is not guaranteed in practice assume(False) - new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype)) + new_data = data.draw( + numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype) + ) nparray[npindexer] = new_data zarray.oindex[zindexer] = new_data assert_array_equal(nparray, zarray[:]) @@ -152,20 +170,13 @@ def test_vindex(data: st.DataObject) -> None: indexer = data.draw( npst.integer_array_indices( - shape=nparray.shape, result_shape=npst.array_shapes(min_side=1, max_dims=None) + shape=nparray.shape, + result_shape=npst.array_shapes(min_side=1, max_dims=None), ) ) actual = zarray.vindex[indexer] assert_array_equal(nparray[indexer], actual) - # FIXME! - # when the indexer is such that a value gets overwritten multiple times, - # I think the output depends on chunking. - # new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype)) - # nparray[indexer] = new_data - # zarray.vindex[indexer] = new_data - # assert_array_equal(nparray, zarray[:]) - @given(store=stores, meta=array_metadata()) # type: ignore[misc] @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") @@ -220,7 +231,9 @@ def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: in orig = metadata.to_dict() rt = metadata_roundtripped.to_dict() - assert deep_equal(orig, rt), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}" + assert deep_equal( + orig, rt + ), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}" # @st.composite @@ -320,7 +333,9 @@ def test_array_metadata_meets_spec(meta: ArrayV2Metadata | ArrayV3Metadata) -> N # version-specific validations if isinstance(meta, ArrayV2Metadata): assert asdict_dict["filters"] != () - assert asdict_dict["filters"] is None or isinstance(asdict_dict["filters"], tuple) + assert asdict_dict["filters"] is None or isinstance( + asdict_dict["filters"], tuple + ) assert asdict_dict["zarr_format"] == 2 else: assert asdict_dict["zarr_format"] == 3 From c65cf828eef63a563ada3ea7f1797c0f8f7b4439 Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Mon, 21 Jul 2025 16:47:46 -0400 Subject: [PATCH 05/12] Fix and test for case where some chunks in shard are all fill --- src/zarr/codecs/sharding.py | 61 +++++++----------------- tests/test_codecs/test_sharding.py | 76 ++++++++++++++++++++++++++++-- tests/test_properties.py | 68 +++++--------------------- 3 files changed, 102 insertions(+), 103 deletions(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index ec4fe476f6..8b64e68130 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -90,9 +90,9 @@ async def get( self, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: assert byte_range is None, "byte_range is not supported within shards" - assert ( - prototype == default_buffer_prototype() - ), f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}" + assert prototype == default_buffer_prototype(), ( + f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}" + ) return self.shard_dict.get(self.chunk_coords) @@ -124,9 +124,7 @@ def chunks_per_shard(self) -> ChunkCoords: def _localize_chunk(self, chunk_coords: ChunkCoords) -> ChunkCoords: return tuple( chunk_i % shard_i - for chunk_i, shard_i in zip( - chunk_coords, self.offsets_and_lengths.shape, strict=False - ) + for chunk_i, shard_i in zip(chunk_coords, self.offsets_and_lengths.shape, strict=False) ) def is_all_empty(self) -> bool: @@ -143,9 +141,7 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None: else: return (int(chunk_start), int(chunk_start + chunk_len)) - def set_chunk_slice( - self, chunk_coords: ChunkCoords, chunk_slice: slice | None - ) -> None: + def set_chunk_slice(self, chunk_coords: ChunkCoords, chunk_slice: slice | None) -> None: localized_chunk = self._localize_chunk(chunk_coords) if chunk_slice is None: self.offsets_and_lengths[localized_chunk] = (MAX_UINT_64, MAX_UINT_64) @@ -167,11 +163,7 @@ def is_dense(self, chunk_byte_length: int) -> bool: # Are all non-empty offsets unique? if len( - { - offset - for offset, _ in sorted_offsets_and_lengths - if offset != MAX_UINT_64 - } + {offset for offset, _ in sorted_offsets_and_lengths if offset != MAX_UINT_64} ) != len(sorted_offsets_and_lengths): return False @@ -275,9 +267,7 @@ def __setitem__(self, chunk_coords: ChunkCoords, value: Buffer) -> None: chunk_start = len(self.buf) chunk_length = len(value) self.buf += value - self.index.set_chunk_slice( - chunk_coords, slice(chunk_start, chunk_start + chunk_length) - ) + self.index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length)) def __delitem__(self, chunk_coords: ChunkCoords) -> None: raise NotImplementedError @@ -291,9 +281,7 @@ async def finalize( if index_location == ShardingCodecIndexLocation.start: empty_chunks_mask = self.index.offsets_and_lengths[..., 0] == MAX_UINT_64 self.index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes) - index_bytes = await index_encoder( - self.index - ) # encode again with corrected offsets + index_bytes = await index_encoder(self.index) # encode again with corrected offsets out_buf = index_bytes + self.buf else: out_buf = self.buf + index_bytes @@ -371,8 +359,7 @@ def __init__( chunk_shape: ChunkCoordsLike, codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(),), index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()), - index_location: ShardingCodecIndexLocation - | str = ShardingCodecIndexLocation.end, + index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end, ) -> None: chunk_shape_parsed = parse_shapelike(chunk_shape) codecs_parsed = parse_codecs(codecs) @@ -402,9 +389,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: object.__setattr__(self, "chunk_shape", parse_shapelike(config["chunk_shape"])) object.__setattr__(self, "codecs", parse_codecs(config["codecs"])) object.__setattr__(self, "index_codecs", parse_codecs(config["index_codecs"])) - object.__setattr__( - self, "index_location", parse_index_location(config["index_location"]) - ) + object.__setattr__(self, "index_location", parse_index_location(config["index_location"])) # Use instance-local lru_cache to avoid memory leaks # object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec)) @@ -433,9 +418,7 @@ def to_dict(self) -> dict[str, JSON]: def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: shard_spec = self._get_chunk_spec(array_spec) - evolved_codecs = tuple( - c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs - ) + evolved_codecs = tuple(c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs) if evolved_codecs != self.codecs: return replace(self, codecs=evolved_codecs) return self @@ -610,9 +593,7 @@ async def _encode_single( shard_array, ) - return await shard_builder.finalize( - self.index_location, self._encode_shard_index - ) + return await shard_builder.finalize(self.index_location, self._encode_shard_index) async def _encode_partial_single( self, @@ -672,8 +653,7 @@ def _is_total_shard( self, all_chunk_coords: set[ChunkCoords], chunks_per_shard: ChunkCoords ) -> bool: return len(all_chunk_coords) == product(chunks_per_shard) and all( - chunk_coords in all_chunk_coords - for chunk_coords in c_order_iter(chunks_per_shard) + chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard) ) async def _decode_shard_index( @@ -699,9 +679,7 @@ async def _encode_shard_index(self, index: _ShardIndex) -> Buffer: .encode( [ ( - get_ndbuffer_class().from_numpy_array( - index.offsets_and_lengths - ), + get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths), self._get_index_chunk_spec(index.chunks_per_shard), ) ], @@ -810,9 +788,10 @@ async def _load_partial_shard_maybe( _ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice)) for chunk_coords in all_chunk_coords # Drop chunks where index lookup fails + # e.g. when write_empty_chunks = False and the chunk is empty if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords)) ] - if len(chunks) < len(all_chunk_coords): + if len(chunks) == 0: return None groups = self._coalesce_chunks(chunks) @@ -854,9 +833,7 @@ def _coalesce_chunks( for chunk in sorted_chunks[1:]: gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop - size_if_coalesced = ( - chunk.byte_slice.stop - current_group[0].byte_slice.start - ) + size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes: current_group.append(chunk) else: @@ -899,9 +876,7 @@ async def _get_group_bytes( return shard_dict - def compute_encoded_size( - self, input_byte_length: int, shard_spec: ArraySpec - ) -> int: + def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int: chunks_per_shard = self._get_chunks_per_shard(shard_spec) return input_byte_length + self._shard_index_size(chunks_per_shard) diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index dbe64a32d5..35940feb47 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -344,6 +344,79 @@ def test_sharding_multiple_chunks_partial_shard_read( assert isinstance(kwargs["byte_range"], (SuffixByteRequest, RangeByteRequest)) +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_read_empty_chunks_within_non_empty_shard_write_empty_false( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """ + Case where + - some, but not all, chunks in the last shard are empty + - the last shard is not complete (array length is not a multiple of shard shape), + this takes us down the partial shard read path + - write_empty_chunks=False so the shard index will have less entries than chunks in the shard + """ + # array with mixed empty and non-empty chunks in second shard + data = np.array([ + # shard 0. full 8 elements, all chunks have some non-fill data + 0, 1, 2, 3, 4, 5, 6, 7, + # shard 1. 6 elements (< shard shape) + 2, 0, # chunk 0, written + 0, 0, # chunk 1, all fill, not written + 4, 5 # chunk 2, written + ], dtype="int32") # fmt: off + + spath = StorePath(store) + a = zarr.create_array( + spath, + shape=(14,), + chunks=(2,), + shards={"shape": (8,), "index_location": index_location}, + dtype="int32", + fill_value=0, + filters=None, + compressors=None, + config={"write_empty_chunks": False}, + ) + a[:] = data + + assert np.array_equal(a[:], data) + + +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +def test_sharding_read_empty_chunks_within_empty_shard_write_empty_false( + store: Store, index_location: ShardingCodecIndexLocation +) -> None: + """ + Case where + - all chunks in last shard are empty + - the last shard is not complete (array length is not a multiple of shard shape), + this takes us down the partial shard read path + - write_empty_chunks=False so the shard index will have no entries + """ + fill_value = -99 + shard_size = 8 + data = np.arange(14, dtype="int32") + data[shard_size:] = fill_value # 2nd shard is all fill value + + spath = StorePath(store) + a = zarr.create_array( + spath, + shape=(14,), + chunks=(2,), + shards={"shape": (shard_size,), "index_location": index_location}, + dtype="int32", + fill_value=fill_value, + filters=None, + compressors=None, + config={"write_empty_chunks": False}, + ) + a[:] = data + + assert np.array_equal(a[:], data) + + @pytest.mark.parametrize("index_location", ["start", "end"]) @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) def test_sharding_partial_shard_read__index_load_fails( @@ -577,7 +650,6 @@ def test_nested_sharding_create_array( filters=None, compressors=None, ) - print(a.metadata.to_dict()) a[:, :, :] = data @@ -637,7 +709,6 @@ async def test_delete_empty_shards(store: Store) -> None: compressors=None, fill_value=1, ) - print(a.metadata.to_dict()) await _AsyncArrayProxy(a)[:, :].set(np.zeros((16, 16))) await _AsyncArrayProxy(a)[8:, :].set(np.ones((8, 16))) await _AsyncArrayProxy(a)[:, 8:].set(np.ones((16, 8))) @@ -682,7 +753,6 @@ async def test_sharding_with_empty_inner_chunk( ) data[:4, :4] = fill_value await a.setitem(..., data) - print("read data") data_read = await a.getitem(...) assert np.array_equal(data_read, data) diff --git a/tests/test_properties.py b/tests/test_properties.py index de302c56b0..e941250872 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -76,17 +76,11 @@ def deep_equal(a: Any, b: Any) -> bool: @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") -@given(data=st.data(), zarr_format=zarr_formats) -def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None: - nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format))) - zarray = data.draw( - arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)) - ) - try: - assert_array_equal(nparray, zarray[:]) - except Exception as e: - breakpoint() - raise e +@given(data=st.data()) +def test_array_roundtrip(data: st.DataObject) -> None: + nparray = data.draw(numpy_arrays()) + zarray = data.draw(arrays(arrays=st.just(nparray))) + assert_array_equal(nparray, zarray[:]) @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") @@ -98,20 +92,12 @@ def test_array_creates_implicit_groups(array): parent = "/".join(ancestry[: i + 1]) if array.metadata.zarr_format == 2: assert ( - sync( - array.store.get( - f"{parent}/.zgroup", prototype=default_buffer_prototype() - ) - ) + sync(array.store.get(f"{parent}/.zgroup", prototype=default_buffer_prototype())) is not None ) elif array.metadata.zarr_format == 3: assert ( - sync( - array.store.get( - f"{parent}/zarr.json", prototype=default_buffer_prototype() - ) - ) + sync(array.store.get(f"{parent}/zarr.json", prototype=default_buffer_prototype())) is not None ) @@ -129,9 +115,7 @@ def test_basic_indexing(data: st.DataObject) -> None: actual = zarray[indexer] assert_array_equal(nparray[indexer], actual) - new_data = data.draw( - numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype) - ) + new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype)) zarray[indexer] = new_data nparray[indexer] = new_data assert_array_equal(nparray, zarray[:]) @@ -153,9 +137,7 @@ def test_oindex(data: st.DataObject) -> None: if isinstance(idxr, np.ndarray) and idxr.size != np.unique(idxr).size: # behaviour of setitem with repeated indices is not guaranteed in practice assume(False) - new_data = data.draw( - numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype) - ) + new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype)) nparray[npindexer] = new_data zarray.oindex[zindexer] = new_data assert_array_equal(nparray, zarray[:]) @@ -231,33 +213,7 @@ def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: in orig = metadata.to_dict() rt = metadata_roundtripped.to_dict() - assert deep_equal( - orig, rt - ), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}" - - -# @st.composite -# def advanced_indices(draw, *, shape): -# basic_idxr = draw( -# basic_indices( -# shape=shape, min_dims=len(shape), max_dims=len(shape), allow_ellipsis=False -# ).filter(lambda x: isinstance(x, tuple)) -# ) - -# int_idxr = draw( -# npst.integer_array_indices(shape=shape, result_shape=npst.array_shapes(max_dims=1)) -# ) -# args = tuple( -# st.sampled_from((l, r)) for l, r in zip_longest(basic_idxr, int_idxr, fillvalue=slice(None)) -# ) -# return draw(st.tuples(*args)) - - -# @given(st.data()) -# def test_roundtrip_object_array(data): -# nparray = data.draw(np_arrays) -# zarray = data.draw(arrays(arrays=st.just(nparray))) -# assert_array_equal(nparray, zarray[:]) + assert deep_equal(orig, rt), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}" def serialized_complex_float_is_valid( @@ -333,9 +289,7 @@ def test_array_metadata_meets_spec(meta: ArrayV2Metadata | ArrayV3Metadata) -> N # version-specific validations if isinstance(meta, ArrayV2Metadata): assert asdict_dict["filters"] != () - assert asdict_dict["filters"] is None or isinstance( - asdict_dict["filters"], tuple - ) + assert asdict_dict["filters"] is None or isinstance(asdict_dict["filters"], tuple) assert asdict_dict["zarr_format"] == 2 else: assert asdict_dict["zarr_format"] == 3 From 501e7a570dccde02b5e88a7b44f5036f5ff9395e Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Mon, 21 Jul 2025 21:42:51 -0400 Subject: [PATCH 06/12] Self review --- src/zarr/codecs/sharding.py | 11 ++++++----- tests/test_codecs/test_sharding.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 8b64e68130..073320f2f1 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -778,7 +778,7 @@ async def _load_partial_shard_maybe( ) -> ShardMapping | None: """ Read chunks from `byte_getter` for the case where the read is less than a full shard. - Returns a mapping of chunk coordinates to bytes. + Returns a mapping of chunk coordinates to bytes or None. """ shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) if shard_index is None: @@ -788,11 +788,9 @@ async def _load_partial_shard_maybe( _ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice)) for chunk_coords in all_chunk_coords # Drop chunks where index lookup fails - # e.g. when write_empty_chunks = False and the chunk is empty + # e.g. empty chunks when write_empty_chunks = False if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords)) ] - if len(chunks) == 0: - return None groups = self._coalesce_chunks(chunks) @@ -816,7 +814,7 @@ def _coalesce_chunks( ) -> list[list[_ChunkCoordsByteSlice]]: """ Combine chunks from a single shard into groups that should be read together - in a single request. + in a single request to the store. Respects the following configuration options: - `sharding.read.coalesce_max_gap_bytes`: The maximum gap between @@ -828,6 +826,9 @@ def _coalesce_chunks( sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start) + if len(sorted_chunks) == 0: + return [] + groups = [] current_group = [sorted_chunks[0]] diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 35940feb47..5df25d4754 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -354,7 +354,7 @@ def test_sharding_read_empty_chunks_within_non_empty_shard_write_empty_false( - some, but not all, chunks in the last shard are empty - the last shard is not complete (array length is not a multiple of shard shape), this takes us down the partial shard read path - - write_empty_chunks=False so the shard index will have less entries than chunks in the shard + - write_empty_chunks=False so the shard index will have fewer entries than chunks in the shard """ # array with mixed empty and non-empty chunks in second shard data = np.array([ From 12c3308452db4c56de9a13011cea0a0a77b8c5a8 Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Mon, 21 Jul 2025 21:44:09 -0400 Subject: [PATCH 07/12] Removing profiling code masquerading as a skipped test --- tests/test_codecs/test_sharding.py | 75 ------------------------------ 1 file changed, 75 deletions(-) diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 5df25d4754..f4c2361a1d 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -204,81 +204,6 @@ def test_sharding_partial_read( assert np.all(read_data == 1) -@pytest.mark.skip("This is profiling rather than a test") -@pytest.mark.slow_hypothesis -@pytest.mark.parametrize("store", ["local"], indirect=["store"]) -def test_partial_shard_read_performance(store: Store) -> None: - import asyncio - import json - from functools import partial - from itertools import product - from timeit import timeit - from unittest.mock import AsyncMock - - # The whole test array is a single shard to keep runtime manageable while - # using a realistic shard size (256 MiB uncompressed, ~115 MiB compressed). - # In practice, the array is likely to be much larger with many shards of this - # rough order of magnitude. There are 512 chunks per shard in this example. - array_shape = (512, 512, 512) - shard_shape = (512, 512, 512) # 256 MiB uncompressed unit16s - chunk_shape = (64, 64, 64) # 512 KiB uncompressed unit16s - dtype = np.uint16 - - a = zarr.create_array( - StorePath(store), - shape=array_shape, - chunks=chunk_shape, - shards=shard_shape, - compressors=BloscCodec(cname="zstd"), - dtype=dtype, - fill_value=np.iinfo(dtype).max, - ) - # Narrow range of values lets zstd compress to about 1/2 of uncompressed size - a[:] = np.random.default_rng(123).integers(low=0, high=50, size=array_shape, dtype=dtype) - - num_calls = 20 - experiments = [] - for concurrency, get_latency, coalesce_max_gap, statement in product( - [1, 10, 100], - [0.0, 0.01], - [-1, 2**20, 10 * 2**20], - ["a[0, :, :]", "a[:, 0, :]", "a[:, :, 0]"], - ): - zarr.config.set( - { - "async.concurrency": concurrency, - "sharding.read.coalesce_max_gap_bytes": coalesce_max_gap, - } - ) - - async def get_with_latency(*args: Any, get_latency: float, **kwargs: Any) -> Any: - await asyncio.sleep(get_latency) - return await store.get(*args, **kwargs) - - store_mock = AsyncMock(wraps=store, spec=store.__class__) - store_mock.get.side_effect = partial(get_with_latency, get_latency=get_latency) - - a = zarr.open_array(StorePath(store_mock)) - - store_mock.reset_mock() - - # Each timeit call accesses a 512x512 slice covering 64 chunks - time = timeit(statement, number=num_calls, globals={"a": a}) / num_calls - experiments.append( - { - "concurrency": concurrency, - "coalesce_max_gap": coalesce_max_gap, - "get_latency": get_latency, - "statement": statement, - "time": time, - "store_get_calls": store_mock.get.call_count, - } - ) - - with open("zarr-python-partial-shard-read-performance-with-coalesce.json", "w") as f: - json.dump(experiments, f) - - @pytest.mark.parametrize("index_location", ["start", "end"]) @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) @pytest.mark.parametrize("coalesce_reads", [True, False]) From 6322ca63ec4cec2fd26644183c04bf8c86932d17 Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Mon, 21 Jul 2025 22:54:56 -0400 Subject: [PATCH 08/12] revert change to indexing.py, not needed --- src/zarr/core/indexing.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 0e0bb664d8..b95c8c642e 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -76,7 +76,9 @@ def err_too_many_indices(selection: Any, shape: ChunkCoords) -> None: raise IndexError(f"too many indices for array; expected {len(shape)}, got {len(selection)}") -def _zarr_array_to_int_or_bool_array(arr: Array) -> npt.NDArray[np.intp] | npt.NDArray[np.bool_]: +def _zarr_array_to_int_or_bool_array( + arr: Array, +) -> npt.NDArray[np.intp] | npt.NDArray[np.bool_]: if arr.dtype.kind in ("i", "b"): return np.asarray(arr) else: @@ -1193,7 +1195,7 @@ def __iter__(self) -> Iterator[ChunkProjection]: stop = self.chunk_nitems_cumsum[chunk_rix] out_selection: slice | npt.NDArray[np.intp] if self.sel_sort is None: - out_selection = np.arange(start, stop) + out_selection = slice(start, stop) else: out_selection = self.sel_sort[start:stop] @@ -1318,7 +1320,8 @@ def pop_fields(selection: SelectionWithFields) -> tuple[Fields | None, Selection fields = fields[0] if len(fields) == 1 else fields selection_tuple = tuple(s for s in selection if not isinstance(s, str)) selection = cast( - "Selection", selection_tuple[0] if len(selection_tuple) == 1 else selection_tuple + "Selection", + selection_tuple[0] if len(selection_tuple) == 1 else selection_tuple, ) return fields, selection From d9a7842537a33482249284626a55dd4c4aeefcea Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Mon, 21 Jul 2025 23:05:41 -0400 Subject: [PATCH 09/12] Add test for duplicate integer indexing into a coalesced group --- tests/test_codecs/test_sharding.py | 51 +++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index f4c2361a1d..c24be7d0d3 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -219,7 +219,8 @@ def test_sharding_multiple_chunks_partial_shard_read( # 1MiB, enough to coalesce all chunks within a shard in this example zarr.config.set({"sharding.read.coalesce_max_gap_bytes": 2**20}) else: - zarr.config.set({"sharding.read.coalesce_max_gap_bytes": -1}) # disable coalescing + # disable coalescing + zarr.config.set({"sharding.read.coalesce_max_gap_bytes": -1}) store_mock = AsyncMock(wraps=store, spec=store.__class__) a = zarr.create_array( @@ -269,6 +270,54 @@ def test_sharding_multiple_chunks_partial_shard_read( assert isinstance(kwargs["byte_range"], (SuffixByteRequest, RangeByteRequest)) +@pytest.mark.parametrize("index_location", ["start", "end"]) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +@pytest.mark.parametrize("coalesce_reads", [True, False]) +def test_sharding_duplicate_read_indexes( + store: Store, index_location: ShardingCodecIndexLocation, coalesce_reads: bool +) -> None: + """ + Check that coalesce optimization parses the grouped reads back out correctly + when there are multiple reads for the same index. + """ + array_shape = (15,) + shard_shape = (8,) + chunk_shape = (2,) + data = np.arange(np.prod(array_shape), dtype="float32").reshape(array_shape) + + if coalesce_reads: + # 1MiB, enough to coalesce all chunks within a shard in this example + zarr.config.set({"sharding.read.coalesce_max_gap_bytes": 2**20}) + else: + # disable coalescing + zarr.config.set({"sharding.read.coalesce_max_gap_bytes": -1}) + + store_mock = AsyncMock(wraps=store, spec=store.__class__) + a = zarr.create_array( + StorePath(store_mock), + shape=data.shape, + chunks=chunk_shape, + shards={"shape": shard_shape, "index_location": index_location}, + compressors=BloscCodec(cname="lz4"), + dtype=data.dtype, + fill_value=-1, + ) + a[:] = data + + store_mock.reset_mock() # ignore store calls during array creation + + # Read the same index multiple times, do that from two chunks which can be coalesced + indexer = [8, 8, 12, 12] + np.array_equal(a[indexer], data[indexer]) + + if coalesce_reads: + # 1 shard index request + 1 coalesced read + assert store_mock.get.call_count == 2 + else: + # 1 shard index request + 2 chunks + assert store_mock.get.call_count == 3 + + @pytest.mark.parametrize("index_location", ["start", "end"]) @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) def test_sharding_read_empty_chunks_within_non_empty_shard_write_empty_false( From 8469e9c0cb5b9ae222e15ccde6a75e18dfb65a84 Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Mon, 21 Jul 2025 23:15:00 -0400 Subject: [PATCH 10/12] Undo change to fill value when initializing shard arrays --- src/zarr/codecs/sharding.py | 4 ++-- tests/test_codecs/test_sharding.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 073320f2f1..cda60589fc 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -469,7 +469,7 @@ async def _decode_single( shape=shard_shape, dtype=shard_spec.dtype.to_native_dtype(), order=shard_spec.order, - fill_value=shard_spec.fill_value, + fill_value=0, ) shard_dict = await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard) @@ -516,7 +516,7 @@ async def _decode_partial_single( shape=indexer.shape, dtype=shard_spec.dtype.to_native_dtype(), order=shard_spec.order, - fill_value=shard_spec.fill_value, + fill_value=0, ) indexed_chunks = list(indexer) diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index c24be7d0d3..f124e14675 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -335,9 +335,9 @@ def test_sharding_read_empty_chunks_within_non_empty_shard_write_empty_false( # shard 0. full 8 elements, all chunks have some non-fill data 0, 1, 2, 3, 4, 5, 6, 7, # shard 1. 6 elements (< shard shape) - 2, 0, # chunk 0, written - 0, 0, # chunk 1, all fill, not written - 4, 5 # chunk 2, written + 2, 0, # chunk 0, written + -9, -9, # chunk 1, all fill, not written + 4, 5 # chunk 2, written ], dtype="int32") # fmt: off spath = StorePath(store) @@ -347,7 +347,7 @@ def test_sharding_read_empty_chunks_within_non_empty_shard_write_empty_false( chunks=(2,), shards={"shape": (8,), "index_location": index_location}, dtype="int32", - fill_value=0, + fill_value=-9, filters=None, compressors=None, config={"write_empty_chunks": False}, From baf1062b625583750e4b77519b15aa49c1cdb4a9 Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Mon, 21 Jul 2025 23:17:51 -0400 Subject: [PATCH 11/12] Undo change to set mypy_path = "src" --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 624af6ab4a..0b7cb9f856 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -352,7 +352,6 @@ ignore = [ [tool.mypy] python_version = "3.11" ignore_missing_imports = true -mypy_path = "src" namespace_packages = false strict = true From 50d8822ab62ca347a63861dc32155d3bdcc4401d Mon Sep 17 00:00:00 2001 From: Alden Keefe Sampson Date: Mon, 21 Jul 2025 23:35:47 -0400 Subject: [PATCH 12/12] Commenting and revert uncessary changes to files for smaller diff --- src/zarr/codecs/sharding.py | 1 + src/zarr/core/indexing.py | 7 ++----- tests/test_properties.py | 35 +++++++++++++++++++++++++++++++++-- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index cda60589fc..b29e24cfb9 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -855,6 +855,7 @@ async def _get_group_bytes( Reads a possibly coalesced group of one or more chunks from a shard. Returns a mapping of chunk coordinates to bytes. """ + # _coalesce_chunks ensures that the group is not empty. group_start = group[0].byte_slice.start group_end = group[-1].byte_slice.stop diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index b95c8c642e..c11889f7f4 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -76,9 +76,7 @@ def err_too_many_indices(selection: Any, shape: ChunkCoords) -> None: raise IndexError(f"too many indices for array; expected {len(shape)}, got {len(selection)}") -def _zarr_array_to_int_or_bool_array( - arr: Array, -) -> npt.NDArray[np.intp] | npt.NDArray[np.bool_]: +def _zarr_array_to_int_or_bool_array(arr: Array) -> npt.NDArray[np.intp] | npt.NDArray[np.bool_]: if arr.dtype.kind in ("i", "b"): return np.asarray(arr) else: @@ -1320,8 +1318,7 @@ def pop_fields(selection: SelectionWithFields) -> tuple[Fields | None, Selection fields = fields[0] if len(fields) == 1 else fields selection_tuple = tuple(s for s in selection if not isinstance(s, str)) selection = cast( - "Selection", - selection_tuple[0] if len(selection_tuple) == 1 else selection_tuple, + "Selection", selection_tuple[0] if len(selection_tuple) == 1 else selection_tuple ) return fields, selection diff --git a/tests/test_properties.py b/tests/test_properties.py index e941250872..27f847fa69 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -152,13 +152,20 @@ def test_vindex(data: st.DataObject) -> None: indexer = data.draw( npst.integer_array_indices( - shape=nparray.shape, - result_shape=npst.array_shapes(min_side=1, max_dims=None), + shape=nparray.shape, result_shape=npst.array_shapes(min_side=1, max_dims=None) ) ) actual = zarray.vindex[indexer] assert_array_equal(nparray[indexer], actual) + # FIXME! + # when the indexer is such that a value gets overwritten multiple times, + # I think the output depends on chunking. + # new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype)) + # nparray[indexer] = new_data + # zarray.vindex[indexer] = new_data + # assert_array_equal(nparray, zarray[:]) + @given(store=stores, meta=array_metadata()) # type: ignore[misc] @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") @@ -216,6 +223,30 @@ def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: in assert deep_equal(orig, rt), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}" +# @st.composite +# def advanced_indices(draw, *, shape): +# basic_idxr = draw( +# basic_indices( +# shape=shape, min_dims=len(shape), max_dims=len(shape), allow_ellipsis=False +# ).filter(lambda x: isinstance(x, tuple)) +# ) + +# int_idxr = draw( +# npst.integer_array_indices(shape=shape, result_shape=npst.array_shapes(max_dims=1)) +# ) +# args = tuple( +# st.sampled_from((l, r)) for l, r in zip_longest(basic_idxr, int_idxr, fillvalue=slice(None)) +# ) +# return draw(st.tuples(*args)) + + +# @given(st.data()) +# def test_roundtrip_object_array(data): +# nparray = data.draw(np_arrays) +# zarray = data.draw(arrays(arrays=st.just(nparray))) +# assert_array_equal(nparray, zarray[:]) + + def serialized_complex_float_is_valid( serialized: tuple[numbers.Real | str, numbers.Real | str], ) -> bool: