Skip to content

Commit 1d575fe

Browse files
committed
WIP: iter
1 parent e6c77e1 commit 1d575fe

File tree

1 file changed

+90
-12
lines changed

1 file changed

+90
-12
lines changed

ngff_zarr/methods/_itkwasm.py

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import Tuple
2+
from itertools import product
23

34
import numpy as np
4-
from dask.array import concatenate, expand_dims, map_blocks, map_overlap, take
5+
from dask.array import concatenate, expand_dims, map_blocks, map_overlap, take, stack, from_array
56

67
from ..ngff_image import NgffImage
78
from ._support import (
@@ -181,9 +182,7 @@ def _downsample_itkwasm(
181182

182183
# Compute overlap for Gaussian blurring for all blocks
183184
is_vector = previous_image.dims[-1] == "c"
184-
block_0_image = itkwasm.image_from_array(
185-
np.ones_like(block_0_input), is_vector=is_vector
186-
)
185+
block_0_image = itkwasm.image_from_array(np.ones_like(block_0_input), is_vector=is_vector)
187186
input_spacing = [previous_image.scale[d] for d in spatial_dims]
188187
block_0_image.spacing = input_spacing
189188
input_origin = [previous_image.translation[d] for d in spatial_dims]
@@ -214,25 +213,18 @@ def _downsample_itkwasm(
214213
block_output.size[dim] == computed_size[dim]
215214
for dim in range(block_output.data.ndim)
216215
)
217-
breakpoint()
218216
output_chunks = list(previous_image.data.chunks)
219217
dims = list(previous_image.dims)
220218
output_chunks_start = 0
221219
while dims[output_chunks_start] not in _spatial_dims:
222220
output_chunks_start += 1
223221
output_chunks = output_chunks[output_chunks_start:]
224-
# if "t" in previous_image.dims:
225-
# dims = list(previous_image.dims)
226-
# t_index = dims.index("t")
227-
# output_chunks.pop(t_index)
228222
for i, c in enumerate(output_chunks):
229223
output_chunks[i] = [
230224
block_output.data.shape[i],
231225
] * len(c)
232226
# Compute output size for block N-1
233-
block_neg1_image = itkwasm.image_from_array(
234-
np.ones_like(block_neg1_input), is_vector=is_vector
235-
)
227+
block_neg1_image = itkwasm.image_from_array(np.ones_like(block_neg1_input), is_vector=is_vector)
236228
block_neg1_image.spacing = input_spacing
237229
block_neg1_image.origin = input_origin
238230
block_output = downsample_bin_shrink(
@@ -251,6 +243,92 @@ def _downsample_itkwasm(
251243
output_chunks[i] = tuple(output_chunks[i])
252244
output_chunks = tuple(output_chunks)
253245

246+
non_spatial_dims = [d for d in dims if d not in _spatial_dims]
247+
if "c" in non_spatial_dims and dims[-1] == "c":
248+
non_spatial_dims.pop("c")
249+
250+
# We'll iterate over each index for the non-spatial dimensions, run the desired
251+
# map_overlap, and aggregate the outputs into a final result.
252+
253+
block_shape = [c[0] for c in previous_image.data.chunks]
254+
# Determine the size for each non-spatial dimension
255+
non_spatial_shapes = [
256+
block_shape[dims.index(dim)] for dim in non_spatial_dims
257+
]
258+
259+
# Collect results for each sub-block
260+
aggregated_blocks = []
261+
for idx in product(*(range(s) for s in non_spatial_shapes)):
262+
# Build the slice object for indexing
263+
slice_obj = []
264+
non_spatial_index = 0
265+
for dim in dims:
266+
if dim in non_spatial_dims:
267+
# Take a single index (like "t=0,1,...") for the non-spatial dimension
268+
slice_obj.append(idx[non_spatial_index])
269+
non_spatial_index += 1
270+
else:
271+
# Keep full slice for spatial/channel dims
272+
slice_obj.append(slice(None))
273+
274+
# Extract the sub-block data for the chosen index from the non-spatial dims
275+
sub_block_data = previous_image.data[tuple(slice_obj)]
276+
277+
downscaled_sub_block = map_overlap(
278+
_itkwasm_blur_and_downsample,
279+
sub_block_data,
280+
shrink_factors=shrink_factors,
281+
kernel_radius=kernel_radius,
282+
smoothing=smoothing,
283+
dtype=dtype,
284+
depth=dict(enumerate(np.flip(kernel_radius))), # overlap is in tzyx
285+
boundary="nearest",
286+
trim=False, # Overlapped region is trimmed in blur_and_downsample to output size
287+
chunks=output_chunks,
288+
)
289+
# sub_block_image = itkwasm.image_from_array(
290+
# sub_block_data,
291+
# is_vector=is_vector # or as needed for your pipeline
292+
# )
293+
# sub_block_image.spacing = input_spacing
294+
# sub_block_image.origin = input_origin
295+
296+
# # Run your map_overlap or other downsampling operation on the sub_block
297+
# # (e.g., downsample_bin_shrink, gaussian, etc.)
298+
# sub_block_output = downsample_bin_shrink(
299+
# sub_block_image,
300+
# shrink_factors,
301+
# information_only=False
302+
# )
303+
304+
# Collect the result for later aggregation
305+
aggregated_blocks.append(downscaled_sub_block)
306+
downscaled_array = da.empty(downscaled_sub_block.shape)
307+
blocks_dask = [from_array(block, chunks=block.shape) for block in aggregated_blocks]
308+
final_dask_array = stack(blocks_dask, axis=0)
309+
310+
# At this point you have a list (aggregated_blocks) of processed sub-blocks.
311+
# You can stitch/concat them back together along the non-spatial dimensions.
312+
# For example, you can shape them into a single array if desired:
313+
# (Rebuild the final data array in the same shape as the non-spatial dims + new spatial dims)
314+
//
315+
// final_result = np.empty(...)
316+
// fill final_result with aggregated_blocks in correct order
317+
//
318+
// The rest of the code can remain unchanged, or you can assign final_result to your final image data
319+
// ...existing code...
320+
// ...existing code...
321+
import dask.array as da
322+
323+
# Suppose we have a list of processed sub-blocks in aggregated_blocks
324+
# Each item is a NumPy array from sub_block_output.data
325+
326+
blocks_dask = [da.from_array(block, chunks=block.shape) for block in aggregated_blocks]
327+
328+
# Combine them along the first axis (or another axis as needed)
329+
330+
# Now final_dask_array is a single dask array representing all sub-blocks
331+
# ...existing code...
254332
if "t" in previous_image.dims:
255333
all_timepoints = []
256334
for timepoint in range(previous_image.data.shape[t_index]):

0 commit comments

Comments
 (0)