
Commit 2d3440f

Support block_id for general_blockwise functions (#593)
* Ensure numblocks match for multiple outputs in general blockwise
* Support block_id for general_blockwise
1 parent eed4e30 commit 2d3440f
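With this change, a function passed to general_blockwise can declare a block_id keyword and receive the chunk coordinates of the block it is computing. Below is a minimal sketch of such a user function; the function name and what it does with block_id are illustrative, not taken from this commit:

import numpy as np

def shift_by_block(x, block_id=None):
    # block_id is the tuple of chunk coordinates, e.g. (0, 1) for the second
    # block along the last axis; here we just add its sum to the chunk values.
    return x + sum(block_id)

general_blockwise detects the keyword (via has_keyword(func, "block_id"), as in the diff below) and supplies the coordinates automatically, so no change to the call site is needed.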

File tree

4 files changed, +134 −38 lines changed

cubed/core/ops.py

Lines changed: 77 additions & 30 deletions
@@ -26,14 +26,9 @@
 from cubed.spec import spec_from_config
 from cubed.storage.backend import open_backend_array
 from cubed.types import T_RegularChunks, T_Shape
-from cubed.utils import (
-    _concatenate2,
-    array_memory,
-    array_size,
-    get_item,
-    offset_to_block_id,
-    to_chunksize,
-)
+from cubed.utils import _concatenate2, array_memory, array_size, get_item
+from cubed.utils import numblocks as compute_numblocks
+from cubed.utils import offset_to_block_id, to_chunksize
 from cubed.vendor.dask.array.core import normalize_chunks
 from cubed.vendor.dask.array.utils import validate_axis
 from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product
@@ -342,6 +337,77 @@ def general_blockwise(
     target_paths=None,
     extra_func_kwargs=None,
     **kwargs,
+) -> Union["Array", Tuple["Array", ...]]:
+    if has_keyword(func, "block_id"):
+        from cubed.array_api.creation_functions import offsets_virtual_array
+
+        # Create an array of index offsets with the same chunk structure as the args,
+        # which we convert to block ids (chunk coordinates) later.
+        array0 = arrays[0]
+        # note that primitive general_blockwise checks that all chunkss have same numblocks
+        numblocks = compute_numblocks(chunkss[0])
+        offsets = offsets_virtual_array(numblocks, array0.spec)
+        new_arrays = arrays + (offsets,)
+
+        def key_function_with_offset(key_function):
+            def wrap(out_key):
+                out_coords = out_key[1:]
+                offset_in_key = ((offsets.name,) + out_coords,)
+                return key_function(out_key) + offset_in_key
+
+            return wrap
+
+        def func_with_block_id(func):
+            def wrap(*a, **kw):
+                offset = int(a[-1])  # convert from 0-d array
+                block_id = offset_to_block_id(offset, numblocks)
+                return func(*a[:-1], block_id=block_id, **kw)
+
+            return wrap
+
+        num_input_blocks = kwargs.pop("num_input_blocks", None)
+        if num_input_blocks is not None:
+            num_input_blocks = num_input_blocks + (1,)  # for offsets array
+
+        return _general_blockwise(
+            func_with_block_id(func),
+            key_function_with_offset(key_function),
+            *new_arrays,
+            shapes=shapes,
+            dtypes=dtypes,
+            chunkss=chunkss,
+            target_stores=target_stores,
+            target_paths=target_paths,
+            extra_func_kwargs=extra_func_kwargs,
+            num_input_blocks=num_input_blocks,
+            **kwargs,
+        )
+
+    return _general_blockwise(
+        func,
+        key_function,
+        *arrays,
+        shapes=shapes,
+        dtypes=dtypes,
+        chunkss=chunkss,
+        target_stores=target_stores,
+        target_paths=target_paths,
+        extra_func_kwargs=extra_func_kwargs,
+        **kwargs,
+    )
+
+
+def _general_blockwise(
+    func,
+    key_function,
+    *arrays,
+    shapes,
+    dtypes,
+    chunkss,
+    target_stores=None,
+    target_paths=None,
+    extra_func_kwargs=None,
+    **kwargs,
 ) -> Union["Array", Tuple["Array", ...]]:
     assert len(arrays) > 0
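The wrapper above appends a virtual offsets array to the inputs, so each task receives a 0-d integer offset that is converted back to chunk coordinates with offset_to_block_id. A hedged sketch of that conversion, assuming it behaves like row-major unravelling over the numblocks grid (the helper lives in cubed.utils and its implementation is not shown in this diff):

import numpy as np

def offset_to_block_id_sketch(offset, numblocks):
    # Turn a linear task offset into per-dimension chunk coordinates,
    # assuming row-major (C) ordering over the grid of blocks.
    return tuple(int(i) for i in np.unravel_index(offset, numblocks))

# For a 2 x 3 grid of blocks: offset 4 maps to block (1, 1)
assert offset_to_block_id_sketch(4, (2, 3)) == (1, 1)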

@@ -504,12 +570,6 @@ def merged_chunk_len_for_indexer(ia, c):
     if _is_chunk_aligned_selection(idx):
         # use general_blockwise, which allows more opportunities for optimization than map_direct

-        from cubed.array_api.creation_functions import offsets_virtual_array
-
-        # general_blockwise doesn't support block_id, so emulate it ourselves
-        numblocks = tuple(map(len, target_chunks))
-        offsets = offsets_virtual_array(numblocks, x.spec)
-
         def key_function(out_key):
             out_coords = out_key[1:]
@@ -521,24 +581,17 @@ def key_function(out_key):
                 in_sel, x.zarray_maybe_lazy.shape, x.zarray_maybe_lazy.chunks
             )

-            offset_in_key = ((offsets.name,) + out_coords,)
-            return (
-                tuple((x.name,) + chunk_coords for (chunk_coords, _, _) in indexer)
-                + offset_in_key
+            return tuple(
+                (x.name,) + chunk_coords for (chunk_coords, _, _) in indexer
             )

-        # since selection is chunk-aligned, we know that we only read one block of x
-        num_input_blocks = (1, 1)  # x, offsets
-
         out = general_blockwise(
             _assemble_index_chunk,
             key_function,
             x,
-            offsets,
             shapes=[shape],
             dtypes=[x.dtype],
             chunkss=[target_chunks],
-            num_input_blocks=num_input_blocks,
             target_chunks=target_chunks,
             selection=selection,
             in_shape=x.shape,
@@ -622,14 +675,8 @@ def _assemble_index_chunk(
     selection=None,
     in_shape=None,
     in_chunksize=None,
+    block_id=None,
 ):
-    # last array contains the offset for the block_id
-    offset = int(arrs[-1])  # convert from 0-d array
-    numblocks = tuple(map(len, target_chunks))
-    block_id = offset_to_block_id(offset, numblocks)
-
-    arrs = arrs[:-1]  # drop offset array
-
     # compute the selection on x required to get the relevant chunk for out_coords
     out_coords = block_id
     in_sel = _target_chunk_selection(target_chunks, out_coords, selection)

cubed/primitive/blockwise.py

Lines changed: 14 additions & 8 deletions
@@ -18,14 +18,9 @@
 from cubed.runtime.types import CubedPipeline
 from cubed.storage.zarr import T_ZarrArray, lazy_zarr_array
 from cubed.types import T_Chunks, T_DType, T_Shape, T_Store
-from cubed.utils import (
-    array_memory,
-    chunk_memory,
-    get_item,
-    map_nested,
-    split_into,
-    to_chunksize,
-)
+from cubed.utils import array_memory, chunk_memory, get_item, map_nested
+from cubed.utils import numblocks as compute_numblocks
+from cubed.utils import split_into, to_chunksize
 from cubed.vendor.dask.array.core import normalize_chunks
 from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
 from cubed.vendor.dask.core import flatten
@@ -261,6 +256,8 @@ def general_blockwise(
     """A more general form of ``blockwise`` that uses a function to specify the block
     mapping, rather than an index notation, and which supports multiple outputs.

+    For multiple outputs, all output arrays must have matching numblocks.
+
     Parameters
     ----------
     func : callable
@@ -308,9 +305,18 @@ def general_blockwise(
     output_chunk_memory = 0
     target_array = []

+    numblocks0 = None
     for i, target_store in enumerate(target_stores):
         chunks_normal = normalize_chunks(chunkss[i], shape=shapes[i], dtype=dtypes[i])
         chunksize = to_chunksize(chunks_normal)
+        if numblocks0 is None:
+            numblocks0 = compute_numblocks(chunks_normal)
+        else:
+            numblocks = compute_numblocks(chunks_normal)
+            if numblocks != numblocks0:
+                raise ValueError(
+                    f"All outputs must have matching number of blocks in each dimension. Chunks specified: {chunkss}"
+                )
         if isinstance(target_store, zarr.Array):
             ta = target_store
         else:

cubed/tests/primitive/test_blockwise.py

Lines changed: 39 additions & 0 deletions
@@ -285,6 +285,45 @@ def block_function(out_key):
     assert_array_equal(res2[:], -np.sqrt(input))


+def test_blockwise_multiple_outputs_fails_different_numblocks(tmp_path):
+    source = create_zarr(
+        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+        dtype=int,
+        chunks=(2, 2),
+        store=tmp_path / "source.zarr",
+    )
+    allowed_mem = 1000
+    target_store1 = tmp_path / "target1.zarr"
+    target_store2 = tmp_path / "target2.zarr"
+
+    in_name = "x"
+
+    def sqrts(x):
+        yield np.sqrt(x)
+        yield -np.sqrt(x)
+
+    def block_function(out_key):
+        out_coords = out_key[1:]
+        return ((in_name, *out_coords),)
+
+    with pytest.raises(
+        ValueError,
+        match="All outputs must have matching number of blocks in each dimension",
+    ):
+        general_blockwise(
+            sqrts,
+            block_function,
+            source,
+            allowed_mem=allowed_mem,
+            reserved_mem=0,
+            target_stores=[target_store1, target_store2],
+            shapes=[(3, 3), (3, 3)],
+            dtypes=[float, float],
+            chunkss=[(2, 2), (4, 2)],  # numblocks differ
+            in_names=[in_name],
+        )
+
+
 def test_make_blockwise_key_function_map():
     func = lambda x: 0

cubed/utils.py

Lines changed: 4 additions & 0 deletions
@@ -147,6 +147,10 @@ def to_chunksize(chunkset: T_RectangularChunks) -> T_RegularChunks:
     return tuple(max(c[0], 1) for c in chunkset)


+def numblocks(chunks: T_RectangularChunks) -> Tuple[int, ...]:
+    return tuple(map(len, chunks))
+
+
 @dataclass
 class StackSummary:
     """Like Python's ``FrameSummary``, but with module information."""
