Skip to content

Commit e6c77e1

Browse files
committed
WIP: tczyz support
1 parent 17f088c commit e6c77e1

File tree

4 files changed

+107
-18
lines changed

4 files changed

+107
-18
lines changed

ngff_zarr/methods/_itkwasm.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
_dim_scale_factors,
1111
_get_block,
1212
_spatial_dims,
13+
_spatial_dims_last,
1314
)
1415

1516
_image_dims: Tuple[str, str, str, str] = ("x", "y", "z", "t")
@@ -89,13 +90,14 @@ def _downsample_itkwasm_bin_shrink(
8990
dim_factors = _dim_scale_factors(dims, scale_factor, previous_dim_factors)
9091
previous_dim_factors = dim_factors
9192
previous_image = _align_chunks(previous_image, default_chunks, dim_factors)
93+
previous_image = _spatial_dims_last(previous_image)
9294

9395
shrink_factors = [dim_factors[sd] for sd in spatial_dims]
9496

9597
block_0 = _get_block(previous_image, 0)
9698

9799
# For consistency for now, do not utilize direction until there is standardized support for
98-
# direction cosines / orientation in OME-NGFF
100+
# direction cosines / orientation in OME-NGFF (v0.6)
99101
# block_0.attrs.pop("direction", None)
100102
block_input = itkwasm.image_from_array(np.ones_like(block_0))
101103
spacing = [previous_image.scale[d] for d in spatial_dims]
@@ -166,6 +168,7 @@ def _downsample_itkwasm(
166168
dim_factors = _dim_scale_factors(dims, scale_factor, previous_dim_factors)
167169
previous_dim_factors = dim_factors
168170
previous_image = _align_chunks(previous_image, default_chunks, dim_factors)
171+
previous_image = _spatial_dims_last(previous_image)
169172

170173
shrink_factors = [dim_factors[sd] for sd in spatial_dims]
171174

@@ -177,7 +180,10 @@ def _downsample_itkwasm(
177180
block_neg1_input = _get_block(previous_image, -1)
178181

179182
# Compute overlap for Gaussian blurring for all blocks
180-
block_0_image = itkwasm.image_from_array(np.ones_like(block_0_input))
183+
is_vector = previous_image.dims[-1] == "c"
184+
block_0_image = itkwasm.image_from_array(
185+
np.ones_like(block_0_input), is_vector=is_vector
186+
)
181187
input_spacing = [previous_image.scale[d] for d in spatial_dims]
182188
block_0_image.spacing = input_spacing
183189
input_origin = [previous_image.translation[d] for d in spatial_dims]
@@ -208,17 +214,25 @@ def _downsample_itkwasm(
208214
block_output.size[dim] == computed_size[dim]
209215
for dim in range(block_output.data.ndim)
210216
)
217+
breakpoint()
211218
output_chunks = list(previous_image.data.chunks)
212-
if "t" in previous_image.dims:
213-
dims = list(previous_image.dims)
214-
t_index = dims.index("t")
215-
output_chunks.pop(t_index)
219+
dims = list(previous_image.dims)
220+
output_chunks_start = 0
221+
while dims[output_chunks_start] not in _spatial_dims:
222+
output_chunks_start += 1
223+
output_chunks = output_chunks[output_chunks_start:]
224+
# if "t" in previous_image.dims:
225+
# dims = list(previous_image.dims)
226+
# t_index = dims.index("t")
227+
# output_chunks.pop(t_index)
216228
for i, c in enumerate(output_chunks):
217229
output_chunks[i] = [
218230
block_output.data.shape[i],
219231
] * len(c)
220232
# Compute output size for block N-1
221-
block_neg1_image = itkwasm.image_from_array(np.ones_like(block_neg1_input))
233+
block_neg1_image = itkwasm.image_from_array(
234+
np.ones_like(block_neg1_input), is_vector=is_vector
235+
)
222236
block_neg1_image.spacing = input_spacing
223237
block_neg1_image.origin = input_origin
224238
block_output = downsample_bin_shrink(

ngff_zarr/methods/_support.py

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,75 @@
11
from typing import List
2+
import copy
23

3-
from dask.array import take
44

55
from ..ngff_image import NgffImage
66

77
_spatial_dims = {"x", "y", "z"}
88

9+
_spatial_dims = {"x", "y", "z"}
10+
11+
12+
def _spatial_dims_last(ngff_image: NgffImage) -> NgffImage:
13+
dims = list(ngff_image.dims)
14+
spatial_dims = [dim for dim in dims if dim in _spatial_dims]
15+
16+
dims_spatial_channel = len(spatial_dims)
17+
if dims[-1] == "c":
18+
dims_spatial_channel += 1
19+
20+
# If spatial dimensions are already last (and 'c' can be last), return the original image
21+
if all(dim in dims[-dims_spatial_channel:] for dim in spatial_dims + ["c"]):
22+
return ngff_image
23+
24+
# Move spatial dimensions to the end, keeping 'c' as the last pre-spatial dimension if present
25+
non_spatial_dims = [dim for dim in dims if dim not in _spatial_dims]
26+
if "c" in non_spatial_dims:
27+
non_spatial_dims.remove("c")
28+
new_dims = non_spatial_dims + ["c"] + spatial_dims
29+
else:
30+
new_dims = non_spatial_dims + spatial_dims
31+
32+
new_order = [dims.index(dim) for dim in new_dims]
33+
34+
if tuple(new_dims) == tuple(ngff_image.dims):
35+
return ngff_image
36+
37+
# Reorder the data array
38+
reordered_data = ngff_image.data.transpose(new_order)
39+
40+
result = copy.copy(ngff_image)
41+
result.data = reordered_data
42+
result.dims = tuple(new_dims)
43+
44+
return result
45+
46+
47+
def _channel_dim_last(ngff_image: NgffImage) -> NgffImage:
48+
if "c" not in ngff_image.dims or ngff_image.dims[-1] == "c":
49+
return ngff_image
50+
51+
dims = list(ngff_image.dims)
52+
# Move 'c' dimension to the end
53+
dims.remove("c")
54+
dims.append("c")
55+
56+
# Reorder the data array
57+
new_order = [ngff_image.dims.index(dim) for dim in dims]
58+
reordered_data = ngff_image.data.transpose(new_order)
59+
60+
result = copy.copy(ngff_image)
61+
result.data = reordered_data
62+
result.dims = tuple(dims)
63+
64+
return result
65+
966

1067
def _dim_scale_factors(dims, scale_factor, previous_dim_factors):
1168
if isinstance(scale_factor, int):
1269
result_scale_factors = {
1370
dim: int(scale_factor / previous_dim_factors[dim])
14-
for dim in dims if dim in _spatial_dims
71+
for dim in dims
72+
if dim in _spatial_dims
1573
}
1674
else:
1775
result_scale_factors = {
@@ -71,11 +129,15 @@ def _get_block(previous_image: NgffImage, block_index: int):
71129
"""Helper method for accessing an enumerated chunk from input"""
72130
block_shape = [c[block_index] for c in previous_image.data.chunks]
73131
block = previous_image.data[tuple([slice(0, s) for s in block_shape])]
74-
# For consistency for now, do not utilize direction until there is standardized support for
75-
# direction cosines / orientation in OME-NGFF
76-
# block.attrs.pop("direction", None)
77-
if "t" in previous_image.dims:
78-
dims = list(previous_image.dims)
79-
t_index = dims.index("t")
80-
block = take(block, 0, t_index)
132+
dims = list(previous_image.dims)
133+
# Also take "c" if it is the last dimension
134+
if dims[-1] == "c":
135+
dims[-1] = "x"
136+
indexer = []
137+
for d in dims:
138+
if d in _spatial_dims:
139+
indexer.append(slice(None))
140+
else:
141+
indexer.append(0)
142+
block = block[tuple(indexer)]
81143
return block

ngff_zarr/ngff_image_to_itk_image.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .methods._support import _spatial_dims
44
from .ngff_image import NgffImage
5+
from .methods._support import _channel_dim_last
56

67

78
def _dtype_to_component_type(dtype):
@@ -37,6 +38,8 @@ def ngff_image_to_itk_image(
3738
):
3839
from itkwasm import IntTypes, PixelTypes
3940

41+
ngff_image = _channel_dim_last(ngff_image)
42+
4043
dims = ngff_image.dims
4144
dimension = 3 if "z" in dims else 2
4245

@@ -63,7 +66,6 @@ def ngff_image_to_itk_image(
6366
origin = [ngff_image.translation[dim] for dim in spatial_dims]
6467
size = [ngff_image.data.shape[dims.index(d)] for d in spatial_dims]
6568

66-
# TODO: reorder as needed
6769
data = np.asarray(ngff_image.data)
6870

6971
image_dict = {

test/test_to_ngff_zarr_itkwasm.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from ngff_zarr import Methods, to_multiscales
1+
import numpy as np
2+
from zarr.storage import MemoryStore
3+
4+
from ngff_zarr import Methods, to_multiscales, to_ngff_image, to_ngff_zarr
25

36
from ._data import verify_against_baseline
47

@@ -11,6 +14,14 @@
1114
pass
1215

1316

17+
def test_channel_support():
18+
data = np.random.randint(0, 256, 16777216).reshape((2, 128, 256, 256))
19+
image = to_ngff_image(data, dims=["c", "z", "y", "x"])
20+
multiscales = to_multiscales(image, scale_factors=[2, 4], chunks=64)
21+
store = MemoryStore()
22+
to_ngff_zarr(store, multiscales)
23+
24+
1425
def test_bin_shrink_isotropic_scale_factors(input_images):
1526
dataset_name = "cthead1"
1627
image = input_images[dataset_name]

0 commit comments

Comments
 (0)