Skip to content

Commit 913cc06

Browse files
committed
PERF: itkwasm downsampling in zyx order
Support memory layout assumed in _large_image_serialization.
1 parent 03fe042 commit 913cc06

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

ngff_zarr/methods/_itkwasm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
_dim_scale_factors,
1313
_get_block,
1414
_spatial_dims,
15-
_spatial_dims_last,
15+
_spatial_dims_last_zyx,
1616
)
1717

1818
_image_dims: Tuple[str, str, str, str] = ("x", "y", "z", "t")
@@ -96,7 +96,7 @@ def _downsample_itkwasm(
9696
previous_dim_factors = dim_factors
9797
previous_image = _align_chunks(previous_image, default_chunks, dim_factors)
9898
# Operate on a contiguous spatial block
99-
previous_image = _spatial_dims_last(previous_image)
99+
previous_image = _spatial_dims_last_zyx(previous_image)
100100
if previous_image.dims != dims:
101101
transposed_dims = True
102102
reorder = [previous_image.dims.index(dim) for dim in dims]
@@ -270,7 +270,7 @@ def _downsample_itkwasm(
270270
out_chunks_list.append(1)
271271
downscaled_array = downscaled_array.rechunk(tuple(out_chunks_list))
272272

273-
# transpose back to original order if needed (_spatial_dims_last transposed the order)
273+
# transpose back to original order if needed (_spatial_dims_zyx transposed the order)
274274
# breakpoint()
275275
if transposed_dims:
276276
downscaled_array = downscaled_array.transpose(reorder)

ngff_zarr/methods/_support.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,44 @@ def _spatial_dims_last(ngff_image: NgffImage) -> NgffImage:
4242
return result
4343

4444

45+
def _spatial_dims_last_zyx(ngff_image: NgffImage) -> NgffImage:
46+
dims = list(ngff_image.dims)
47+
spatial_dims = [dim for dim in dims if dim in _spatial_dims]
48+
49+
# If spatial dimensions are already zyx, return the original image
50+
if spatial_dims == ["z", "y", "x"] or spatial_dims == ["y", "x"]:
51+
dims_spatial_channel = len(spatial_dims)
52+
if dims[-1] == "c":
53+
dims_spatial_channel += 1
54+
55+
# If spatial dimensions are already last (and 'c' can be last), return the original image
56+
if all(dim in dims[-dims_spatial_channel:] for dim in spatial_dims + ["c"]):
57+
return ngff_image
58+
59+
# Move spatial dimensions to the end, keeping 'c' as the last pre-spatial dimension if present
60+
non_spatial_dims = [dim for dim in dims if dim not in _spatial_dims]
61+
new_spatial_dims = ["z", "y", "x"][-len(spatial_dims) :]
62+
if "c" in non_spatial_dims:
63+
non_spatial_dims.remove("c")
64+
new_dims = non_spatial_dims + ["c"] + new_spatial_dims
65+
else:
66+
new_dims = non_spatial_dims + new_spatial_dims
67+
68+
new_order = [dims.index(dim) for dim in new_dims]
69+
70+
if tuple(new_dims) == tuple(ngff_image.dims):
71+
return ngff_image
72+
73+
# Reorder the data array
74+
reordered_data = ngff_image.data.transpose(new_order)
75+
76+
result = copy.copy(ngff_image)
77+
result.data = reordered_data
78+
result.dims = tuple(new_dims)
79+
80+
return result
81+
82+
4583
def _channel_dim_last(ngff_image: NgffImage) -> NgffImage:
4684
if "c" not in ngff_image.dims or ngff_image.dims[-1] == "c":
4785
return ngff_image

0 commit comments

Comments
 (0)