@@ -188,7 +188,7 @@ def _downsample_itkwasm(
188188 non_spatial_shapes = previous_image .data .shape [:output_chunks_start ]
189189
190190 # Collect results for each sub-block
191- aggregated_blocks = {}
191+ aggregated_blocks = []
192192 for idx in product (* (range (s ) for s in non_spatial_shapes )):
193193 # Build the slice object for indexing
194194 slice_obj = []
@@ -231,11 +231,24 @@ def _downsample_itkwasm(
231231 trim = False , # Overlapped region is trimmed in blur_and_downsample to output size
232232 chunks = output_chunks ,
233233 )
234- aggregated_blocks [ slice_obj ] = downscaled_sub_block
234+ aggregated_blocks . append ( downscaled_sub_block )
235235 downscaled_array_shape = non_spatial_shapes + downscaled_sub_block .shape
236236 downscaled_array = dask .array .empty (downscaled_array_shape , dtype = dtype )
237- for slice_obj , block in aggregated_blocks .items ():
238- downscaled_array [slice_obj ] = block
237+ for downscaled_sub_block in aggregated_blocks :
238+ # Build the slice object for indexing
239+ slice_obj = []
240+ non_spatial_index = 0
241+ for dim in previous_image .dims :
242+ if dim in non_spatial_dims :
243+ # Take a single index (like "t=0,1,...") for the non-spatial dimension
244+ slice_obj .append (idx [non_spatial_index ])
245+ non_spatial_index += 1
246+ else :
247+ # Keep full slice for spatial/channel dims
248+ slice_obj .append (slice (None ))
249+
250+ slice_obj = tuple (slice_obj )
251+ downscaled_array [slice_obj ] = downscaled_sub_block
239252 else :
240253 data = previous_image .data
241254 if smoothing == "bin_shrink" :
0 commit comments