2 changes: 1 addition & 1 deletion .github/workflows/pixi-test.yml
@@ -15,7 +15,7 @@ jobs:

      - uses: prefix-dev/setup-pixi@v0.8.1
        with:
-          pixi-version: v0.39.3
+          pixi-version: v0.44.0
          frozen: true

      - name: Run tests
61 changes: 59 additions & 2 deletions ngff_zarr/to_ngff_zarr.py
@@ -2,7 +2,7 @@
from collections.abc import MutableMapping
from dataclasses import asdict
from pathlib import Path, PurePosixPath
-from typing import Optional, Union
+from typing import Optional, Union, Tuple, Mapping, Any
from packaging import version

if sys.version_info < (3, 10):
@@ -150,6 +150,14 @@ def to_ngff_zarr(
    use_tensorstore: bool = False,
    chunk_store: Optional[StoreLike] = None,
    progress: Optional[Union[NgffProgress, NgffProgressCallback]] = None,
+    chunks_per_shard: Optional[
+        Union[
+            int,
+            Tuple[int, ...],
+            Tuple[Tuple[int, ...], ...],
+            Mapping[Any, Union[None, int, Tuple[int, ...]]],
+        ]
+    ] = None,
    **kwargs,
) -> None:
    """
@@ -174,8 +182,11 @@ def to_ngff_zarr(
        for storage of both chunks and metadata.
    :type chunk_store: StoreLike, optional

-    :type progress: RichDaskProgress
    :param progress: Optional progress logger
+    :type progress: RichDaskProgress
+
+    :param chunks_per_shard: Number of chunks, per axis, in each shard. If None, no sharding is applied. Requires OME-Zarr version >= 0.5 and, unless use_tensorstore is set, zarr-python >= 3.0.0b1.
+    :type chunks_per_shard: int, tuple, or dict, optional

    :param **kwargs: Passed to the zarr.creation.create() function, e.g., compression options.
    """
@@ -189,6 +200,16 @@ def to_ngff_zarr(
if version != "0.4" and version != "0.5":
raise ValueError(f"Unsupported version: {version}")

if chunks_per_shard is not None:
if version == "0.4":
raise ValueError(
"Sharding is only supported for OME-Zarr version 0.5 and later"
)
if not use_tensorstore and zarr_version_major < 3:
raise ValueError(
"Sharding requires zarr-python version >= 3.0.0b1 for OME-Zarr version >= 0.5"
)

    metadata = multiscales.metadata
    if version == "0.4" and isinstance(metadata, Metadata_v05):
        metadata = Metadata_v04(
@@ -268,6 +289,39 @@ def to_ngff_zarr(
            dim_factors = {d: 1 for d in dims}
        previous_dim_factors = dim_factors

+        sharding_kwargs = {}
+        if chunks_per_shard is not None:
+            # Shard shape = chunk shape scaled by the chunks-per-shard count, per axis.
+            c0 = tuple([c[0] for c in arr.chunks])
+            if isinstance(chunks_per_shard, int):
+                shards = tuple([c * chunks_per_shard for c in c0])
+            elif isinstance(chunks_per_shard, (tuple, list)):
+                if len(chunks_per_shard) != arr.ndim:
+                    raise ValueError(
+                        f"chunks_per_shard must be a tuple of length {arr.ndim}"
+                    )
+                shards = tuple([c * c0[i] for i, c in enumerate(chunks_per_shard)])
+            elif isinstance(chunks_per_shard, dict):
+                shards = {d: c * chunks_per_shard.get(d, 1) for d, c in zip(dims, c0)}
+                shards = tuple([shards[d] for d in dims])
+            else:
+                raise ValueError("chunks_per_shard must be an int, tuple, or dict")
+            from zarr.codecs.sharding import ShardingCodec
+
+            if "codec" in kwargs:
+                # zarr-python 3's ShardingCodec takes its nested codecs via `codecs`.
+                nested_codec = kwargs.pop("codec")
+                sharding_codec = ShardingCodec(
+                    chunk_shape=c0,
+                    codecs=nested_codec,
+                )
+            else:
+                sharding_codec = ShardingCodec(chunk_shape=c0)
+            if "codecs" in kwargs:
+                # list.append returns None, so append first, then assign.
+                previous_codecs = kwargs.pop("codecs")
+                previous_codecs.append(sharding_codec)
+                sharding_kwargs["codecs"] = previous_codecs
+            else:
+                sharding_kwargs["codecs"] = [sharding_codec]
+            arr = arr.rechunk(shards)
+
        if memory_usage(image) > config.memory_target and multiscales.scale_factors:
            shrink_factors = []
            for dim in dims:
@@ -285,6 +339,7 @@ def to_ngff_zarr(
                store=store,
                path=path,
                mode="a",
+                **sharding_kwargs,
                **zarr_kwargs,
                **dimension_names_kwargs,
                **format_kwargs,
@@ -433,6 +488,7 @@ def to_ngff_zarr(
            overwrite=False,
            compute=True,
            return_stored=False,
+            **sharding_kwargs,
            **zarr_kwargs,
            **format_kwargs,
            **dimension_names_kwargs,
@@ -464,6 +520,7 @@ def to_ngff_zarr(
            overwrite=False,
            compute=True,
            return_stored=False,
+            **sharding_kwargs,
            **zarr_kwargs,
            **format_kwargs,
            **dimension_names_kwargs,
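For orientation, a hedged end-to-end sketch of the new parameter as this diff wires it up. It assumes zarr-python >= 3.0.0b1 is installed; the synthetic image, the use of to_ngff_image on a plain NumPy array, and the output path are illustrative, not taken from this PR:

import numpy as np
from ngff_zarr import to_multiscales, to_ngff_image, to_ngff_zarr

# Synthetic 2D image with the same 64x64 chunking the new test uses.
data = np.random.randint(0, 256, size=(256, 256)).astype(np.uint8)
image = to_ngff_image(data, dims=("y", "x"))
multiscales = to_multiscales(image, [2, 4], chunks=(64, 64))

# Two chunks per shard along every axis: 64x64 chunks inside 128x128 shards.
to_ngff_zarr(
    "image.ome.zarr",
    multiscales,
    version="0.5",  # sharding requires OME-Zarr >= 0.5
    chunks_per_shard=2,  # also accepts a per-axis tuple or a {dim: count} dict
)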
12 changes: 6 additions & 6 deletions pixi.lock

Generated lock file; diff not rendered.

103 changes: 103 additions & 0 deletions test/test_to_ngff_zarr_sharding.py
@@ -0,0 +1,103 @@
import json
from packaging import version
import tempfile

import pytest

import zarr.storage
import zarr

from ngff_zarr import Methods, to_multiscales, to_ngff_zarr

from ._data import verify_against_baseline

zarr_version = version.parse(zarr.__version__)

# Skip tests if zarr version is less than 3.0.0b1
pytestmark = pytest.mark.skipif(
    zarr_version < version.parse("3.0.0b1"), reason="zarr version < 3.0.0b1"
)


def test_zarr_python_sharding(input_images):
    dataset_name = "cthead1"
    image = input_images[dataset_name]
    baseline_name = "2_4/RFC3_GAUSSIAN.zarr"
    chunks = (64, 64)
    multiscales = to_multiscales(
        image, [2, 4], chunks=chunks, method=Methods.ITKWASM_GAUSSIAN
    )
    store = zarr.storage.MemoryStore()

    chunks_per_shard = 2
    version = "0.4"
    with pytest.raises(ValueError):
        to_ngff_zarr(
            store, multiscales, version=version, chunks_per_shard=chunks_per_shard
        )

    version = "0.5"
    with tempfile.TemporaryDirectory() as tmpdir:
        to_ngff_zarr(
            tmpdir, multiscales, version=version, chunks_per_shard=chunks_per_shard
        )
        with open(tmpdir + "/zarr.json") as f:
            zarr_json = json.load(f)
        assert zarr_json["zarr_format"] == 3
        metadata = zarr_json["consolidated_metadata"]["metadata"]
        scale0 = metadata["scale0/image"]
        assert scale0["shape"][0] == 256
        assert scale0["shape"][1] == 256
        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][0] == 128
        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][1] == 128
        assert scale0["codecs"][0]["name"] == "sharding_indexed"
        assert scale0["codecs"][0]["configuration"]["chunk_shape"][0] == 64
        assert scale0["codecs"][0]["configuration"]["chunk_shape"][1] == 64

        verify_against_baseline(
            dataset_name, baseline_name, multiscales, version=version
        )

    chunks_per_shard = (2, 1)
    with tempfile.TemporaryDirectory() as tmpdir:
        to_ngff_zarr(
            tmpdir, multiscales, version=version, chunks_per_shard=chunks_per_shard
        )
        with open(tmpdir + "/zarr.json") as f:
            zarr_json = json.load(f)
        assert zarr_json["zarr_format"] == 3
        metadata = zarr_json["consolidated_metadata"]["metadata"]
        scale0 = metadata["scale0/image"]
        assert scale0["shape"][0] == 256
        assert scale0["shape"][1] == 256
        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][0] == 128
        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][1] == 64
        assert scale0["codecs"][0]["name"] == "sharding_indexed"
        assert scale0["codecs"][0]["configuration"]["chunk_shape"][0] == 64
        assert scale0["codecs"][0]["configuration"]["chunk_shape"][1] == 64

        verify_against_baseline(
            dataset_name, baseline_name, multiscales, version=version
        )

    chunks_per_shard = {"y": 2, "x": 1}
    with tempfile.TemporaryDirectory() as tmpdir:
        to_ngff_zarr(
            tmpdir, multiscales, version=version, chunks_per_shard=chunks_per_shard
        )
        with open(tmpdir + "/zarr.json") as f:
            zarr_json = json.load(f)
        assert zarr_json["zarr_format"] == 3
        metadata = zarr_json["consolidated_metadata"]["metadata"]
        scale0 = metadata["scale0/image"]
        assert scale0["shape"][0] == 256
        assert scale0["shape"][1] == 256
        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][0] == 128
        assert scale0["chunk_grid"]["configuration"]["chunk_shape"][1] == 64
        assert scale0["codecs"][0]["name"] == "sharding_indexed"
        assert scale0["codecs"][0]["configuration"]["chunk_shape"][0] == 64
        assert scale0["codecs"][0]["configuration"]["chunk_shape"][1] == 64

        verify_against_baseline(
            dataset_name, baseline_name, multiscales, version=version
        )
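The assertions above walk the consolidated OME-Zarr v3 metadata index-by-index; the same spot-check can be written more compactly, since json.load yields plain lists. A sketch, continuing from a tmpdir written as in the first case of the test (the "scale0/image" key follows ngff-zarr's multiscale layout):

import json

# tmpdir holds an OME-Zarr 0.5 store written with chunks=(64, 64)
# and chunks_per_shard=2, as in the test above.
with open(tmpdir + "/zarr.json") as f:
    meta = json.load(f)["consolidated_metadata"]["metadata"]

scale0 = meta["scale0/image"]
# Outer chunk grid advertises the shard shape: 64 * 2 = 128 per axis.
assert scale0["chunk_grid"]["configuration"]["chunk_shape"] == [128, 128]
# The sharding codec's inner chunk_shape keeps the original 64x64 chunks.
sharding = scale0["codecs"][0]
assert sharding["name"] == "sharding_indexed"
assert sharding["configuration"]["chunk_shape"] == [64, 64]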