Skip to content

Commit eeefe4a

Browse files
committed
BUG: tzyxc bin shrink crash
Punt on ITK_BIN_SHRINK, but throw a more informative error. Add multi-component support for ITKWASM_BIN_SHINK. Re: #157
1 parent 4ac1fc0 commit eeefe4a

File tree

7 files changed

+225
-118
lines changed

7 files changed

+225
-118
lines changed

ngff_zarr/methods/_itk.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ def _downsample_itk_bin_shrink(
149149
# For consistency for now, do not utilize direction until there is standardized support for
150150
# direction cosines / orientation in OME-NGFF
151151
# block_0.attrs.pop("direction", None)
152+
if "c" in previous_image.dims:
153+
raise ValueError(
154+
"Downsampling with ITK BinShrinkImageFilter does not support channel dimension 'c'. "
155+
"Use ITK Gaussian downsampling instead."
156+
)
152157
block_input = itk.image_from_array(np.ones_like(block_0))
153158
spacing = [previous_image.scale[d] for d in spatial_dims]
154159
block_input.SetSpacing(spacing)

ngff_zarr/methods/_itkwasm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def _downsample_itkwasm(
146146
]
147147
assert all(
148148
block_output.size[dim] == computed_size[dim]
149-
for dim in range(block_output.data.ndim)
149+
for dim in range(len(block_output.size))
150150
)
151151
output_chunks = list(previous_image.data.chunks)
152152
output_chunks_start = 0
@@ -172,7 +172,7 @@ def _downsample_itkwasm(
172172
]
173173
assert all(
174174
block_output.size[dim] == computed_size[dim]
175-
for dim in range(block_output.data.ndim)
175+
for dim in range(len(block_output.size))
176176
)
177177
for i in range(len(output_chunks)):
178178
output_chunks[i][-1] = block_output.data.shape[i]
@@ -181,7 +181,7 @@ def _downsample_itkwasm(
181181

182182
non_spatial_dims = [d for d in dims if d not in _spatial_dims]
183183
if "c" in non_spatial_dims and previous_image.dims[-1] == "c":
184-
non_spatial_dims.pop("c")
184+
non_spatial_dims.remove("c")
185185

186186
if output_chunks_start > 0:
187187
# We'll iterate over each index for the non-spatial dimensions, run the desired

ngff_zarr/methods/_support.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,25 @@ def _dim_scale_factors(dims, scale_factor, previous_dim_factors):
111111
result_scale_factors = {
112112
d: int(scale_factor[d] / previous_dim_factors[d]) for d in scale_factor
113113
}
114+
# if a dim is not in the scale_factors, add it with a scale factor of 1
115+
for d in dims:
116+
if d not in result_scale_factors:
117+
result_scale_factors[d] = 1
118+
114119
return result_scale_factors
115120

121+
116122
def _update_previous_dim_factors(scale_factor, spatial_dims, previous_dim_factors):
123+
previous_dim_factors = copy.copy(previous_dim_factors)
117124
if isinstance(scale_factor, int):
118-
previous_dim_factors = { d : scale_factor for d in spatial_dims }
125+
for d in spatial_dims:
126+
previous_dim_factors[d] = scale_factor
119127
else:
120-
previous_dim_factors = scale_factor
128+
for d in scale_factor:
129+
previous_dim_factors[d] = scale_factor[d]
121130
return previous_dim_factors
122131

132+
123133
def _align_chunks(previous_image, default_chunks, dim_factors):
124134
block_0_shape = [c[0] for c in previous_image.data.chunks]
125135

pixi.lock

Lines changed: 111 additions & 111 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"dask[array]",
3333
"importlib_resources",
3434
"itkwasm >= 1.0b183",
35-
"itkwasm-downsample >= 1.2.0",
35+
"itkwasm-downsample >= 1.7.1",
3636
"numpy",
3737
"platformdirs",
3838
"psutil; sys_platform != \"emscripten\"",

test/test_to_ngff_zarr_itk.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from ngff_zarr import Methods, to_multiscales
3+
from ngff_zarr import Methods, to_multiscales, to_ngff_image
44

55
from ._data import verify_against_baseline
66
import platform
@@ -20,6 +20,41 @@ def test_bin_shrink_isotropic_scale_factors(input_images):
2020
verify_against_baseline(dataset_name, baseline_name, multiscales)
2121

2222

23+
def test_bin_shrink_tzyxc():
24+
import dask.array as da
25+
26+
test_array = da.ones((96, 64, 64, 64, 2), chunks=(1, 64, 64, 64, 1), dtype="uint8")
27+
img = to_ngff_image(
28+
test_array,
29+
dims=list("tzyxc"),
30+
scale={
31+
"t": 100_000.0,
32+
"z": 0.98,
33+
"y": 0.98,
34+
"x": 0.98,
35+
"c": 1.0,
36+
},
37+
axes_units={
38+
"t": "millisecond",
39+
"z": "micrometer",
40+
"y": "micrometer",
41+
"x": "micrometer",
42+
},
43+
name="000x_000y_000z",
44+
)
45+
46+
# expect a ValueError
47+
try:
48+
to_multiscales(
49+
img,
50+
scale_factors=[{"z": 2, "y": 2, "x": 2}, {"z": 4, "y": 4, "x": 4}],
51+
method=Methods.ITK_BIN_SHRINK,
52+
)
53+
assert False, "Expected ValueError for non-spatial dimensions"
54+
except ValueError:
55+
pass
56+
57+
2358
@pytest.mark.skipif(
2459
platform.system() == "Linux" and platform.machine() == "aarch64",
2560
reason="Skipping on Linux ARM systems",

test/test_to_ngff_zarr_itkwasm.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,63 @@ def test_bin_shrink_tczyx():
107107
assert multiscales.images[1].data.shape[2] == 16
108108

109109

110+
def test_bin_shrink_tzyxc():
111+
import dask.array as da
112+
113+
test_array = da.ones(
114+
(96, 256, 256, 256, 2), chunks=(1, 128, 128, 128, 1), dtype="uint16"
115+
)
116+
img = to_ngff_image(
117+
test_array,
118+
dims=list("tzyxc"),
119+
scale={
120+
"t": 100_000.0,
121+
"z": 0.98,
122+
"y": 0.98,
123+
"x": 0.98,
124+
"c": 1.0,
125+
},
126+
axes_units={
127+
"t": "millisecond",
128+
"z": "micrometer",
129+
"y": "micrometer",
130+
"x": "micrometer",
131+
},
132+
name="000x_000y_000z",
133+
)
134+
# test_array = da.ones(
135+
# (64, 64, 64, 2), chunks=(64, 64, 64, 1), dtype="uint16"
136+
# )
137+
# img = to_ngff_image(
138+
# test_array,
139+
# dims=list("zyxc"),
140+
# scale={
141+
# "z": 0.98,
142+
# "y": 0.98,
143+
# "x": 0.98,
144+
# "c": 1.0,
145+
# },
146+
# axes_units={
147+
# "z": "micrometer",
148+
# "y": "micrometer",
149+
# "x": "micrometer",
150+
# },
151+
# name="000x_000y_000z",
152+
# )
153+
# from ngff_zarr import ngff_image_to_itk_image
154+
# itk_img = ngff_image_to_itk_image(img)
155+
# from itkwasm_image_io import write_image
156+
# write_image(itk_img, "test_bin_shrink_zyxc.iwi")
157+
# write_image(itk_img, "test_bin_shrink_zyxc.iwi.cbor")
158+
159+
multiscales = to_multiscales(
160+
img,
161+
scale_factors=[{"z": 2, "y": 2, "x": 2}, {"z": 4, "y": 4, "x": 4}],
162+
method=Methods.ITKWASM_BIN_SHRINK,
163+
)
164+
assert len(multiscales.images) == 3
165+
166+
110167
def test_bin_shrink_isotropic_scale_factors(input_images):
111168
dataset_name = "cthead1"
112169
image = input_images[dataset_name]

0 commit comments

Comments
 (0)