Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions ngff_zarr/ngff_image_to_itk_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional

import numpy as np
from dask.array.core import Array as DaskArray

from .methods._support import _spatial_dims
from .ngff_image import NgffImage
from .methods._support import _channel_dim_last

Expand Down Expand Up @@ -35,13 +37,47 @@ def _dtype_to_component_type(dtype):
def ngff_image_to_itk_image(
ngff_image: NgffImage,
wasm: bool = True,
t_index: Optional[int] = None,
):
"""Convert a NgffImage to an ITK image."""
from itkwasm import IntTypes, PixelTypes

if t_index is not None and "t" in ngff_image.dims:
t_dim_index = ngff_image.dims.index("t")
new_dims = list(ngff_image.dims)
new_dims.remove("t")
new_dims = tuple(new_dims)
new_scale = {dim: ngff_image.scale[dim] for dim in new_dims}
new_translation = {dim: ngff_image.translation[dim] for dim in new_dims}
new_axes_units = {dim: ngff_image.axes_units[dim] for dim in new_dims}
if isinstance(ngff_image.data, DaskArray):
from dask.array import take

new_data = take(ngff_image.data, t_index, axis=t_dim_index)
else:
new_data = ngff_image.data.take(t_index, axis=t_dim_index)
ngff_image = NgffImage(
data=new_data,
dims=new_dims,
name=ngff_image.name,
scale=new_scale,
translation=new_translation,
axes_units=new_axes_units,
)

ngff_image = _channel_dim_last(ngff_image)

dims = ngff_image.dims
dimension = 3 if "z" in dims else 2
itk_dimension_names = {"x", "y", "z", "t"}
itk_dims = [dim for dim in dims if dim in itk_dimension_names]
itk_dims.sort()
if "t" in itk_dims:
itk_dims.remove("t")
itk_dims.append("t")
spacing = [ngff_image.scale[dim] for dim in itk_dims]
origin = [ngff_image.translation[dim] for dim in itk_dims]
size = [ngff_image.data.shape[dims.index(d)] for d in itk_dims]
dimension = len(itk_dims)

componentType = _dtype_to_component_type(ngff_image.data.dtype)

Expand All @@ -60,12 +96,6 @@ def ngff_image_to_itk_image(
"components": components,
}

spatial_dims = [dim for dim in dims if dim in _spatial_dims]
spatial_dims.sort()
spacing = [ngff_image.scale[dim] for dim in spatial_dims]
origin = [ngff_image.translation[dim] for dim in spatial_dims]
size = [ngff_image.data.shape[dims.index(d)] for d in spatial_dims]

data = np.asarray(ngff_image.data)

image_dict = {
Expand Down
29 changes: 28 additions & 1 deletion test/test_ngff_image_to_itk_image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itk
import itkwasm
import numpy as np
from ngff_zarr import itk_image_to_ngff_image, ngff_image_to_itk_image
from ngff_zarr import itk_image_to_ngff_image, ngff_image_to_itk_image, from_ngff_zarr

from ._data import test_data_dir

Expand Down Expand Up @@ -62,3 +62,30 @@ def test_2d_itkwasm_image(input_images): # noqa: ARG001
assert np.array_equal(
np.asarray(itkwasm_image.data), np.asarray(itkwasm_image_back.data)
)


def test_t_index(input_images): # noqa: ARG001
dataset_name = "13457537"
store_path = test_data_dir / "input" / f"{dataset_name}.zarr"
multiscales = from_ngff_zarr(store_path)
ngff_image = multiscales.images[0]

itk_image = ngff_image_to_itk_image(ngff_image)

assert itk_image.imageType.dimension == 4
assert itk_image.imageType.components == 6
assert len(itk_image.size) == 4
assert len(itk_image.spacing) == 4
assert len(itk_image.origin) == 4
assert len(itk_image.direction) == 4
assert itk_image.data.shape == (18, 12, 223, 198, 6)

itk_image = ngff_image_to_itk_image(ngff_image, t_index=0)

assert itk_image.imageType.dimension == 3
assert itk_image.imageType.components == 6
assert len(itk_image.size) == 3
assert len(itk_image.spacing) == 3
assert len(itk_image.origin) == 3
assert len(itk_image.direction) == 3
assert itk_image.data.shape == (12, 223, 198, 6)
Loading