From c85b2f40ae0c087b9400b76bde365e1b5fecb52e Mon Sep 17 00:00:00 2001 From: Matt McCormick Date: Fri, 10 Jan 2025 15:58:14 -0500 Subject: [PATCH] ENH: Support time in ngff_image_to_itk_image Can either result in a spatial dimension + 1 dimension ITK Image or a time index can be specified. --- ngff_zarr/ngff_image_to_itk_image.py | 46 +++++++++++++++++++++++----- test/test_ngff_image_to_itk_image.py | 29 +++++++++++++++++- 2 files changed, 66 insertions(+), 9 deletions(-) diff --git a/ngff_zarr/ngff_image_to_itk_image.py b/ngff_zarr/ngff_image_to_itk_image.py index c2ab37bb..f31bff0f 100644 --- a/ngff_zarr/ngff_image_to_itk_image.py +++ b/ngff_zarr/ngff_image_to_itk_image.py @@ -1,6 +1,8 @@ +from typing import Optional + import numpy as np +from dask.array.core import Array as DaskArray -from .methods._support import _spatial_dims from .ngff_image import NgffImage from .methods._support import _channel_dim_last @@ -35,13 +37,47 @@ def _dtype_to_component_type(dtype): def ngff_image_to_itk_image( ngff_image: NgffImage, wasm: bool = True, + t_index: Optional[int] = None, ): + """Convert a NgffImage to an ITK image.""" from itkwasm import IntTypes, PixelTypes + if t_index is not None and "t" in ngff_image.dims: + t_dim_index = ngff_image.dims.index("t") + new_dims = list(ngff_image.dims) + new_dims.remove("t") + new_dims = tuple(new_dims) + new_scale = {dim: ngff_image.scale[dim] for dim in new_dims} + new_translation = {dim: ngff_image.translation[dim] for dim in new_dims} + new_axes_units = {dim: ngff_image.axes_units[dim] for dim in new_dims} + if isinstance(ngff_image.data, DaskArray): + from dask.array import take + + new_data = take(ngff_image.data, t_index, axis=t_dim_index) + else: + new_data = ngff_image.data.take(t_index, axis=t_dim_index) + ngff_image = NgffImage( + data=new_data, + dims=new_dims, + name=ngff_image.name, + scale=new_scale, + translation=new_translation, + axes_units=new_axes_units, + ) + ngff_image = _channel_dim_last(ngff_image) dims = ngff_image.dims - dimension = 3 if "z" in dims else 2 + itk_dimension_names = {"x", "y", "z", "t"} + itk_dims = [dim for dim in dims if dim in itk_dimension_names] + itk_dims.sort() + if "t" in itk_dims: + itk_dims.remove("t") + itk_dims.append("t") + spacing = [ngff_image.scale[dim] for dim in itk_dims] + origin = [ngff_image.translation[dim] for dim in itk_dims] + size = [ngff_image.data.shape[dims.index(d)] for d in itk_dims] + dimension = len(itk_dims) componentType = _dtype_to_component_type(ngff_image.data.dtype) @@ -60,12 +96,6 @@ def ngff_image_to_itk_image( "components": components, } - spatial_dims = [dim for dim in dims if dim in _spatial_dims] - spatial_dims.sort() - spacing = [ngff_image.scale[dim] for dim in spatial_dims] - origin = [ngff_image.translation[dim] for dim in spatial_dims] - size = [ngff_image.data.shape[dims.index(d)] for d in spatial_dims] - data = np.asarray(ngff_image.data) image_dict = { diff --git a/test/test_ngff_image_to_itk_image.py b/test/test_ngff_image_to_itk_image.py index eec258ae..e88ae21d 100644 --- a/test/test_ngff_image_to_itk_image.py +++ b/test/test_ngff_image_to_itk_image.py @@ -1,7 +1,7 @@ import itk import itkwasm import numpy as np -from ngff_zarr import itk_image_to_ngff_image, ngff_image_to_itk_image +from ngff_zarr import itk_image_to_ngff_image, ngff_image_to_itk_image, from_ngff_zarr from ._data import test_data_dir @@ -62,3 +62,30 @@ def test_2d_itkwasm_image(input_images): # noqa: ARG001 assert np.array_equal( np.asarray(itkwasm_image.data), np.asarray(itkwasm_image_back.data) ) + + +def test_t_index(input_images): # noqa: ARG001 + dataset_name = "13457537" + store_path = test_data_dir / "input" / f"{dataset_name}.zarr" + multiscales = from_ngff_zarr(store_path) + ngff_image = multiscales.images[0] + + itk_image = ngff_image_to_itk_image(ngff_image) + + assert itk_image.imageType.dimension == 4 + assert itk_image.imageType.components == 6 + assert len(itk_image.size) == 4 + assert len(itk_image.spacing) == 4 + assert len(itk_image.origin) == 4 + assert len(itk_image.direction) == 4 + assert itk_image.data.shape == (18, 12, 223, 198, 6) + + itk_image = ngff_image_to_itk_image(ngff_image, t_index=0) + + assert itk_image.imageType.dimension == 3 + assert itk_image.imageType.components == 6 + assert len(itk_image.size) == 3 + assert len(itk_image.spacing) == 3 + assert len(itk_image.origin) == 3 + assert len(itk_image.direction) == 3 + assert itk_image.data.shape == (12, 223, 198, 6)