26
26
from cubed .spec import spec_from_config
27
27
from cubed .storage .backend import open_backend_array
28
28
from cubed .types import T_RegularChunks , T_Shape
29
- from cubed .utils import (
30
- _concatenate2 ,
31
- array_memory ,
32
- array_size ,
33
- get_item ,
34
- offset_to_block_id ,
35
- to_chunksize ,
36
- )
29
+ from cubed .utils import _concatenate2 , array_memory , array_size , get_item
30
+ from cubed .utils import numblocks as compute_numblocks
31
+ from cubed .utils import offset_to_block_id , to_chunksize
37
32
from cubed .vendor .dask .array .core import normalize_chunks
38
33
from cubed .vendor .dask .array .utils import validate_axis
39
34
from cubed .vendor .dask .blockwise import broadcast_dimensions , lol_product
@@ -342,6 +337,77 @@ def general_blockwise(
342
337
target_paths = None ,
343
338
extra_func_kwargs = None ,
344
339
** kwargs ,
340
+ ) -> Union ["Array" , Tuple ["Array" , ...]]:
341
+ if has_keyword (func , "block_id" ):
342
+ from cubed .array_api .creation_functions import offsets_virtual_array
343
+
344
+ # Create an array of index offsets with the same chunk structure as the args,
345
+ # which we convert to block ids (chunk coordinates) later.
346
+ array0 = arrays [0 ]
347
+ # note that primitive general_blockwise checks that all chunkss have same numblocks
348
+ numblocks = compute_numblocks (chunkss [0 ])
349
+ offsets = offsets_virtual_array (numblocks , array0 .spec )
350
+ new_arrays = arrays + (offsets ,)
351
+
352
def key_function_with_offset(key_function):
    """Extend ``key_function`` so each output key also names the matching
    chunk of the closed-over ``offsets`` array (same chunk coordinates
    as the output key)."""

    def wrapper(out_key):
        coords = out_key[1:]
        # 1-tuple holding the offsets-array key for these coordinates
        extra = ((offsets.name,) + coords,)
        return key_function(out_key) + extra

    return wrapper
359
+
360
def func_with_block_id(func):
    """Adapt ``func`` to receive a ``block_id`` keyword argument.

    The last positional argument is expected to be a 0-d offsets array
    (appended by the caller); it is converted to chunk coordinates via
    the closed-over ``numblocks`` and passed on as ``block_id``.
    """

    def wrapper(*args, **kwargs):
        linear_offset = int(args[-1])  # convert from 0-d array
        coords = offset_to_block_id(linear_offset, numblocks)
        return func(*args[:-1], block_id=coords, **kwargs)

    return wrapper
367
+
368
+ num_input_blocks = kwargs .pop ("num_input_blocks" , None )
369
+ if num_input_blocks is not None :
370
+ num_input_blocks = num_input_blocks + (1 ,) # for offsets array
371
+
372
+ return _general_blockwise (
373
+ func_with_block_id (func ),
374
+ key_function_with_offset (key_function ),
375
+ * new_arrays ,
376
+ shapes = shapes ,
377
+ dtypes = dtypes ,
378
+ chunkss = chunkss ,
379
+ target_stores = target_stores ,
380
+ target_paths = target_paths ,
381
+ extra_func_kwargs = extra_func_kwargs ,
382
+ num_input_blocks = num_input_blocks ,
383
+ ** kwargs ,
384
+ )
385
+
386
+ return _general_blockwise (
387
+ func ,
388
+ key_function ,
389
+ * arrays ,
390
+ shapes = shapes ,
391
+ dtypes = dtypes ,
392
+ chunkss = chunkss ,
393
+ target_stores = target_stores ,
394
+ target_paths = target_paths ,
395
+ extra_func_kwargs = extra_func_kwargs ,
396
+ ** kwargs ,
397
+ )
398
+
399
+
400
+ def _general_blockwise (
401
+ func ,
402
+ key_function ,
403
+ * arrays ,
404
+ shapes ,
405
+ dtypes ,
406
+ chunkss ,
407
+ target_stores = None ,
408
+ target_paths = None ,
409
+ extra_func_kwargs = None ,
410
+ ** kwargs ,
345
411
) -> Union ["Array" , Tuple ["Array" , ...]]:
346
412
assert len (arrays ) > 0
347
413
@@ -504,12 +570,6 @@ def merged_chunk_len_for_indexer(ia, c):
504
570
if _is_chunk_aligned_selection (idx ):
505
571
# use general_blockwise, which allows more opportunities for optimization than map_direct
506
572
507
- from cubed .array_api .creation_functions import offsets_virtual_array
508
-
509
- # general_blockwise doesn't support block_id, so emulate it ourselves
510
- numblocks = tuple (map (len , target_chunks ))
511
- offsets = offsets_virtual_array (numblocks , x .spec )
512
-
513
573
def key_function (out_key ):
514
574
out_coords = out_key [1 :]
515
575
@@ -521,24 +581,17 @@ def key_function(out_key):
521
581
in_sel , x .zarray_maybe_lazy .shape , x .zarray_maybe_lazy .chunks
522
582
)
523
583
524
- offset_in_key = ((offsets .name ,) + out_coords ,)
525
- return (
526
- tuple ((x .name ,) + chunk_coords for (chunk_coords , _ , _ ) in indexer )
527
- + offset_in_key
584
+ return tuple (
585
+ (x .name ,) + chunk_coords for (chunk_coords , _ , _ ) in indexer
528
586
)
529
587
530
- # since selection is chunk-aligned, we know that we only read one block of x
531
- num_input_blocks = (1 , 1 ) # x, offsets
532
-
533
588
out = general_blockwise (
534
589
_assemble_index_chunk ,
535
590
key_function ,
536
591
x ,
537
- offsets ,
538
592
shapes = [shape ],
539
593
dtypes = [x .dtype ],
540
594
chunkss = [target_chunks ],
541
- num_input_blocks = num_input_blocks ,
542
595
target_chunks = target_chunks ,
543
596
selection = selection ,
544
597
in_shape = x .shape ,
@@ -622,14 +675,8 @@ def _assemble_index_chunk(
622
675
selection = None ,
623
676
in_shape = None ,
624
677
in_chunksize = None ,
678
+ block_id = None ,
625
679
):
626
- # last array contains the offset for the block_id
627
- offset = int (arrs [- 1 ]) # convert from 0-d array
628
- numblocks = tuple (map (len , target_chunks ))
629
- block_id = offset_to_block_id (offset , numblocks )
630
-
631
- arrs = arrs [:- 1 ] # drop offset array
632
-
633
680
# compute the selection on x required to get the relevant chunk for out_coords
634
681
out_coords = block_id
635
682
in_sel = _target_chunk_selection (target_chunks , out_coords , selection )
0 commit comments