 from typing import Tuple
+from itertools import product

 import numpy as np
-from dask.array import concatenate, expand_dims, map_blocks, map_overlap, take
+from dask.array import concatenate, expand_dims, map_blocks, map_overlap, take, stack

 from ..ngff_image import NgffImage
 from ._support import (
@@ -181,9 +182,7 @@ def _downsample_itkwasm(

     # Compute overlap for Gaussian blurring for all blocks
     is_vector = previous_image.dims[-1] == "c"
-    block_0_image = itkwasm.image_from_array(
-        np.ones_like(block_0_input), is_vector=is_vector
-    )
+    block_0_image = itkwasm.image_from_array(np.ones_like(block_0_input), is_vector=is_vector)
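+    # The ones array is only a placeholder: the image carries the block's
+    # shape, spacing, and origin so the blur overlap and downsampled output
+    # size can be computed without touching real data.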
     input_spacing = [previous_image.scale[d] for d in spatial_dims]
     block_0_image.spacing = input_spacing
     input_origin = [previous_image.translation[d] for d in spatial_dims]
@@ -214,25 +213,18 @@ def _downsample_itkwasm(
         block_output.size[dim] == computed_size[dim]
         for dim in range(block_output.data.ndim)
     )
-    breakpoint()
     output_chunks = list(previous_image.data.chunks)
     dims = list(previous_image.dims)
     output_chunks_start = 0
     while dims[output_chunks_start] not in _spatial_dims:
         output_chunks_start += 1
     output_chunks = output_chunks[output_chunks_start:]
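+    # output_chunks now covers only the spatial (and trailing channel) axes;
+    # the leading non-spatial axes are handled by iterating over them below.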
-    # if "t" in previous_image.dims:
-    #     dims = list(previous_image.dims)
-    #     t_index = dims.index("t")
-    #     output_chunks.pop(t_index)
     for i, c in enumerate(output_chunks):
         output_chunks[i] = [
             block_output.data.shape[i],
         ] * len(c)
     # Compute output size for block N-1
-    block_neg1_image = itkwasm.image_from_array(
-        np.ones_like(block_neg1_input), is_vector=is_vector
-    )
+    block_neg1_image = itkwasm.image_from_array(np.ones_like(block_neg1_input), is_vector=is_vector)
     block_neg1_image.spacing = input_spacing
     block_neg1_image.origin = input_origin
     block_output = downsample_bin_shrink(
@@ -251,6 +243,92 @@ def _downsample_itkwasm(
         output_chunks[i] = tuple(output_chunks[i])
     output_chunks = tuple(output_chunks)

+    non_spatial_dims = [d for d in dims if d not in _spatial_dims]
+    # A trailing "c" axis is the vector/channel component handled by itkwasm
+    # itself (see is_vector above), so it is not iterated over here.
+    if "c" in non_spatial_dims and dims[-1] == "c":
+        non_spatial_dims.remove("c")
+
+    # We'll iterate over each index for the non-spatial dimensions, run the desired
+    # map_overlap, and aggregate the outputs into a final result.
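+    # For example, for dims ("t", "z", "y", "x") with three timepoints the
+    # indices visited are (0,), (1,), and (2,), each selecting one timepoint.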
+
+    # Determine the full extent of each non-spatial dimension so that every
+    # index (e.g. every timepoint) is visited.
+    non_spatial_shapes = [
+        previous_image.data.shape[dims.index(dim)] for dim in non_spatial_dims
+    ]
+
+    # Collect results for each sub-block
+    aggregated_blocks = []
+    for idx in product(*(range(s) for s in non_spatial_shapes)):
+        # Build the slice object for indexing
+        slice_obj = []
+        non_spatial_index = 0
+        for dim in dims:
+            if dim in non_spatial_dims:
+                # Take a single index (like "t=0,1,...") for the non-spatial dimension
+                slice_obj.append(idx[non_spatial_index])
+                non_spatial_index += 1
+            else:
+                # Keep full slice for spatial/channel dims
+                slice_obj.append(slice(None))
+
+        # Extract the sub-block data for the chosen index from the non-spatial dims
+        sub_block_data = previous_image.data[tuple(slice_obj)]
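+        # e.g. with dims ("t", "z", "y", "x") and idx == (1,), slice_obj is
+        # [1, slice(None), slice(None), slice(None)], so sub_block_data is the
+        # (z, y, x) volume for timepoint 1.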
+
+        downscaled_sub_block = map_overlap(
+            _itkwasm_blur_and_downsample,
+            sub_block_data,
+            shrink_factors=shrink_factors,
+            kernel_radius=kernel_radius,
+            smoothing=smoothing,
+            dtype=dtype,
+            depth=dict(enumerate(np.flip(kernel_radius))),  # overlap is in tzyx
+            boundary="nearest",
+            trim=False,  # Overlapped region is trimmed in blur_and_downsample to output size
+            chunks=output_chunks,
+        )
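+        # downscaled_sub_block is still a lazy dask array: it holds the
+        # blurred-and-downsampled spatial volume for this non-spatial index
+        # and is only computed when the final array is evaluated.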
+        # sub_block_image = itkwasm.image_from_array(
+        #     sub_block_data,
+        #     is_vector=is_vector,  # or as needed for your pipeline
+        # )
+        # sub_block_image.spacing = input_spacing
+        # sub_block_image.origin = input_origin
+
+        # # Run your map_overlap or other downsampling operation on the sub_block
+        # # (e.g., downsample_bin_shrink, gaussian, etc.)
+        # sub_block_output = downsample_bin_shrink(
+        #     sub_block_image,
+        #     shrink_factors,
+        #     information_only=False,
+        # )
+
+        # Collect the result for later aggregation
+        aggregated_blocks.append(downscaled_sub_block)
+
+    # Stitch the processed sub-blocks back together along the non-spatial
+    # dimensions. Each entry in aggregated_blocks is already a lazy dask
+    # array produced by map_overlap, so it can be stacked directly; with more
+    # than one non-spatial dimension the leading axis may still need to be
+    # reshaped back to the original non-spatial extents.
+    final_dask_array = stack(aggregated_blocks, axis=0)
     if "t" in previous_image.dims:
         all_timepoints = []
         for timepoint in range(previous_image.data.shape[t_index]):