diff --git a/.github/workflows/pixi-test.yml b/.github/workflows/pixi-test.yml
index 8716430a..7a5bdf5e 100644
--- a/.github/workflows/pixi-test.yml
+++ b/.github/workflows/pixi-test.yml
@@ -15,7 +15,7 @@ jobs:
       - uses: prefix-dev/setup-pixi@v0.8.1
         with:
-          pixi-version: v0.39.3
+          pixi-version: v0.44.0
           frozen: true
       - name: Run tests
diff --git a/ngff_zarr/to_ngff_zarr.py b/ngff_zarr/to_ngff_zarr.py
index 2c2e2c36..b99aa141 100644
--- a/ngff_zarr/to_ngff_zarr.py
+++ b/ngff_zarr/to_ngff_zarr.py
@@ -2,7 +2,7 @@
 from collections.abc import MutableMapping
 from dataclasses import asdict
 from pathlib import Path, PurePosixPath
-from typing import Optional, Union
+from typing import Optional, Union, Tuple, Dict
 from packaging import version
 
 if sys.version_info < (3, 10):
@@ -150,6 +150,13 @@ def to_ngff_zarr(
     use_tensorstore: bool = False,
     chunk_store: Optional[StoreLike] = None,
     progress: Optional[Union[NgffProgress, NgffProgressCallback]] = None,
+    chunks_per_shard: Optional[
+        Union[
+            int,
+            Tuple[int, ...],
+            Dict[str, int],
+        ]
+    ] = None,
     **kwargs,
 ) -> None:
     """
@@ -174,8 +181,11 @@ def to_ngff_zarr(
         for storage of both chunks and metadata.
     :type chunk_store: StoreLike, optional
 
-    :type progress: RichDaskProgress
     :param progress: Optional progress logger
+    :type progress: RichDaskProgress
+
+    :param chunks_per_shard: Number of chunks along each axis in a shard. If None, no sharding. Requires OME-Zarr version >= 0.5 and zarr-python >= 3.
+    :type chunks_per_shard: int, tuple, or dict, optional
 
     :param **kwargs: Passed to the zarr.creation.create() function, e.g., compression options.
     """
@@ -189,6 +199,16 @@ def to_ngff_zarr(
     if version != "0.4" and version != "0.5":
         raise ValueError(f"Unsupported version: {version}")
 
+    if chunks_per_shard is not None:
+        if version == "0.4":
+            raise ValueError(
+                "Sharding is only supported for OME-Zarr version 0.5 and later"
+            )
+        if not use_tensorstore and zarr_version_major < 3:
+            raise ValueError(
+                "Sharding requires zarr-python version >= 3.0.0b1 for OME-Zarr version >= 0.5"
+            )
+
     metadata = multiscales.metadata
     if version == "0.4" and isinstance(metadata, Metadata_v05):
         metadata = Metadata_v04(
@@ -268,6 +288,39 @@ def to_ngff_zarr(
             dim_factors = {d: 1 for d in dims}
             previous_dim_factors = dim_factors
 
+        sharding_kwargs = {}
+        if chunks_per_shard is not None:
+            c0 = tuple([c[0] for c in arr.chunks])
+            if isinstance(chunks_per_shard, int):
+                shards = tuple([c * chunks_per_shard for c in c0])
+            elif isinstance(chunks_per_shard, (tuple, list)):
+                if len(chunks_per_shard) != arr.ndim:
+                    raise ValueError(
+                        f"chunks_per_shard must be a tuple of length {arr.ndim}"
+                    )
+                shards = tuple([c * c0[i] for i, c in enumerate(chunks_per_shard)])
+            elif isinstance(chunks_per_shard, dict):
+                shards = {d: c * chunks_per_shard.get(d, 1) for d, c in zip(dims, c0)}
+                shards = tuple([shards[d] for d in dims])
+            else:
+                raise ValueError("chunks_per_shard must be an int, tuple, or dict")
+            from zarr.codecs.sharding import ShardingCodec
+
+            if "codec" in kwargs:
+                nested_codec = kwargs.pop("codec")
+                sharding_codec = ShardingCodec(
+                    chunk_shape=c0,
+                    codecs=[nested_codec],
+                )
+            else:
+                sharding_codec = ShardingCodec(chunk_shape=c0)
+            if "codecs" in kwargs:
+                previous_codecs = kwargs.pop("codecs")
+                sharding_kwargs["codecs"] = previous_codecs + [sharding_codec]
+            else:
+                sharding_kwargs["codecs"] = [sharding_codec]
+            arr = arr.rechunk(shards)
+
         if memory_usage(image) > config.memory_target and multiscales.scale_factors:
             shrink_factors = []
             for dim in dims:
@@ -285,6 +338,7 @@
                 store=store,
                 path=path,
                 mode="a",
+                **sharding_kwargs,
                 **zarr_kwargs,
                 **dimension_names_kwargs,
                 **format_kwargs,
@@ -433,6 +487,7 @@ def to_ngff_zarr(
                 overwrite=False,
                 compute=True,
                 return_stored=False,
+                **sharding_kwargs,
                 **zarr_kwargs,
                 **format_kwargs,
                 **dimension_names_kwargs,
@@ -464,6 +519,7 @@ def to_ngff_zarr(
                 overwrite=False,
                 compute=True,
                 return_stored=False,
+                **sharding_kwargs,
                 **zarr_kwargs,
                 **format_kwargs,
                 **dimension_names_kwargs,
diff --git a/pixi.lock b/pixi.lock
index fd3d09ac..2ffcb11e 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -1591,7 +1591,7 @@ packages:
   - pytest>=7.0 ; extra == 'test'
   - trustme ; extra == 'test'
   - truststore>=0.9.1 ; python_full_version >= '3.10' and extra == 'test'
-  - uvloop>=0.21 ; python_full_version < '3.14' and platform_python_implementation == 'CPython' and platform_system != 'Windows' and extra == 'test'
+  - uvloop>=0.21 ; python_full_version < '3.14' and platform_python_implementation == 'CPython' and sys_platform != 'win32' and extra == 'test'
   - packaging ; extra == 'doc'
   - sphinx~=7.4 ; extra == 'doc'
   - sphinx-rtd-theme ; extra == 'doc'
@@ -1959,7 +1959,7 @@ packages:
   version: 8.1.8
   sha256: 63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2
   requires_dist:
-  - colorama ; platform_system == 'Windows'
+  - colorama ; sys_platform == 'win32'
   - importlib-metadata ; python_full_version < '3.8'
   requires_python: '>=3.7'
 - pypi: https://files.pythonhosted.org/packages/7e/e8/64c37fadfc2816a7701fa8a6ed8d87327c7d54eacfbfb6edab14a2f2be75/cloudpickle-3.1.1-py3-none-any.whl
@@ -3495,7 +3495,7 @@ packages:
   timestamp: 1736683572099
 - pypi: .
   name: ngff-zarr
-  version: 0.11.0
+  version: 0.12.2
   sha256: d8f01f63ea17ad4d8e7fa6bd0afd6cf5d84ee41cbc6d65dc1d82ad1ac13eaffb
   requires_dist:
   - dask[array]
@@ -3573,19 +3573,19 @@ packages:
   sha256: 00a364924fd2d600bcce6e2ced96b47c40eb5f9d84bf4b0207aa208d9ce6cd1c
   requires_dist:
   - numpy>=1.24
-  - crc32c>=2.7 ; extra == 'crc32c'
   - sphinx ; extra == 'docs'
   - sphinx-issues ; extra == 'docs'
   - pydata-sphinx-theme ; extra == 'docs'
   - numpydoc ; extra == 'docs'
-  - msgpack ; extra == 'msgpack'
-  - pcodec>=0.2,<0.3 ; extra == 'pcodec'
   - coverage ; extra == 'test'
   - pytest ; extra == 'test'
   - pytest-cov ; extra == 'test'
   - importlib-metadata ; extra == 'test-extras'
+  - msgpack ; extra == 'msgpack'
   - zfpy>=1.0.0 ; extra == 'zfpy'
   - numpy<2.0.0 ; extra == 'zfpy'
+  - pcodec>=0.2,<0.3 ; extra == 'pcodec'
+  - crc32c>=2.7 ; extra == 'crc32c'
   requires_python: '>=3.11'
 - pypi: https://files.pythonhosted.org/packages/4b/e2/ac784ac4b6e5841e4bfb7d3e7e38497450df18eebff9465990a8ac9aecfc/numcodecs-0.14.1-cp313-cp313-win_amd64.whl
   name: numcodecs
diff --git a/test/test_to_ngff_zarr_sharding.py b/test/test_to_ngff_zarr_sharding.py
new file mode 100644
index 00000000..a817d236
--- /dev/null
+++ b/test/test_to_ngff_zarr_sharding.py
@@ -0,0 +1,103 @@
+import json
+from packaging import version
+import tempfile
+
+import pytest
+
+import zarr.storage
+import zarr
+
+from ngff_zarr import Methods, to_multiscales, to_ngff_zarr
+
+from ._data import verify_against_baseline
+
+zarr_version = version.parse(zarr.__version__)
+
+# Skip tests if zarr version is less than 3.0.0b1
+pytestmark = pytest.mark.skipif(
+    zarr_version < version.parse("3.0.0b1"), reason="zarr version < 3.0.0b1"
+)
+
+
+def test_zarr_python_sharding(input_images):
+    dataset_name = "cthead1"
+    image = input_images[dataset_name]
+    baseline_name = "2_4/RFC3_GAUSSIAN.zarr"
+    chunks = (64, 64)
+    multiscales = to_multiscales(
+        image, [2, 4], chunks=chunks, method=Methods.ITKWASM_GAUSSIAN
+    )
+    store = zarr.storage.MemoryStore()
+
+    chunks_per_shard = 2
+    version = "0.4"
+    with pytest.raises(ValueError):
+        to_ngff_zarr(
+            store, multiscales, version=version, chunks_per_shard=chunks_per_shard
+        )
+
+    version = "0.5"
+    with tempfile.TemporaryDirectory() as tmpdir:
+        to_ngff_zarr(
+            tmpdir, multiscales, version=version, chunks_per_shard=chunks_per_shard
+        )
+        with open(tmpdir + "/zarr.json") as f:
+            zarr_json = json.load(f)
+        assert zarr_json["zarr_format"] == 3
+        metadata = zarr_json["consolidated_metadata"]["metadata"]
+        scale0 = metadata["scale0/image"]
+        assert scale0["shape"][0] == 256
+        assert scale0["shape"][1] == 256
+        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][0] == 128
+        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][1] == 128
+        assert scale0["codecs"][0]["name"] == "sharding_indexed"
+        assert scale0["codecs"][0]["configuration"]["chunk_shape"][0] == 64
+        assert scale0["codecs"][0]["configuration"]["chunk_shape"][1] == 64
+
+        verify_against_baseline(
+            dataset_name, baseline_name, multiscales, version=version
+        )
+
+    chunks_per_shard = (2, 1)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        to_ngff_zarr(
+            tmpdir, multiscales, version=version, chunks_per_shard=chunks_per_shard
+        )
+        with open(tmpdir + "/zarr.json") as f:
+            zarr_json = json.load(f)
+        assert zarr_json["zarr_format"] == 3
+        metadata = zarr_json["consolidated_metadata"]["metadata"]
+        scale0 = metadata["scale0/image"]
+        assert scale0["shape"][0] == 256
+        assert scale0["shape"][1] == 256
+        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][0] == 128
+        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][1] == 64
+        assert scale0["codecs"][0]["name"] == "sharding_indexed"
+        assert scale0["codecs"][0]["configuration"]["chunk_shape"][0] == 64
+        assert scale0["codecs"][0]["configuration"]["chunk_shape"][1] == 64
+
+        verify_against_baseline(
+            dataset_name, baseline_name, multiscales, version=version
+        )
+
+    chunks_per_shard = {"y": 2, "x": 1}
+    with tempfile.TemporaryDirectory() as tmpdir:
+        to_ngff_zarr(
+            tmpdir, multiscales, version=version, chunks_per_shard=chunks_per_shard
+        )
+        with open(tmpdir + "/zarr.json") as f:
+            zarr_json = json.load(f)
+        assert zarr_json["zarr_format"] == 3
+        metadata = zarr_json["consolidated_metadata"]["metadata"]
+        scale0 = metadata["scale0/image"]
+        assert scale0["shape"][0] == 256
+        assert scale0["shape"][1] == 256
+        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][0] == 128
+        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][1] == 64
+        assert scale0["codecs"][0]["name"] == "sharding_indexed"
+        assert scale0["codecs"][0]["configuration"]["chunk_shape"][0] == 64
+        assert scale0["codecs"][0]["configuration"]["chunk_shape"][1] == 64
+
+        verify_against_baseline(
+            dataset_name, baseline_name, multiscales, version=version
+        )
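
Reviewer note, not part of the diff: a minimal usage sketch of the new chunks_per_shard parameter, which accepts an int, a per-axis tuple, or a dict keyed by dimension name, as exercised by the test above. The random input array and the *.ome.zarr output paths below are illustrative assumptions; the API calls mirror the test, and zarr-python >= 3.0.0b1 is assumed.

import numpy as np

from ngff_zarr import Methods, to_multiscales, to_ngff_zarr

# A 256x256 image with 64x64 chunks, matching the geometry asserted in the test.
data = np.random.randint(0, 256, size=(256, 256), dtype=np.uint8)
multiscales = to_multiscales(
    data, [2, 4], chunks=(64, 64), method=Methods.ITKWASM_GAUSSIAN
)

# int: 2 chunks per shard along every axis -> 128x128 shards of 64x64 chunks.
to_ngff_zarr("int.ome.zarr", multiscales, version="0.5", chunks_per_shard=2)

# tuple: one entry per array dimension -> 128x64 shards.
to_ngff_zarr("tuple.ome.zarr", multiscales, version="0.5", chunks_per_shard=(2, 1))

# dict: keyed by dimension name; omitted dimensions default to 1 chunk per shard.
to_ngff_zarr(
    "dict.ome.zarr", multiscales, version="0.5", chunks_per_shard={"y": 2, "x": 1}
)

The int and tuple forms here correspond to the 128x128 and 128x64 chunk_grid shard shapes checked by the test's zarr.json assertions.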