@@ -466,43 +466,45 @@ def merged_chunk_len_for_indexer(ia, c):
466
466
return (c // ia .step ) * ia .step
467
467
468
468
shape = idx .newshape (x .shape )
469
+
469
470
if shape == x .shape :
470
- # no op case
471
- return x
472
- dtype = x .dtype
473
- chunks = tuple (
474
- chunk_len_for_indexer (ia , c )
475
- for ia , c in zip (idx .args , x .chunksize )
476
- if not isinstance (ia , ndindex .Integer )
477
- )
471
+ # no op case (except possibly newaxis applied below)
472
+ out = x
473
+ else :
474
+ dtype = x .dtype
475
+ chunks = tuple (
476
+ chunk_len_for_indexer (ia , c )
477
+ for ia , c in zip (idx .args , x .chunksize )
478
+ if not isinstance (ia , ndindex .Integer )
479
+ )
478
480
479
- target_chunks = normalize_chunks (chunks , shape , dtype = dtype )
481
+ target_chunks = normalize_chunks (chunks , shape , dtype = dtype )
480
482
481
- # memory allocated by reading one chunk from input array
482
- # note that although the output chunk will overlap multiple input chunks, zarr will
483
- # read the chunks in series, reusing the buffer
484
- extra_projected_mem = x .chunkmem
483
+ # memory allocated by reading one chunk from input array
484
+ # note that although the output chunk will overlap multiple input chunks, zarr will
485
+ # read the chunks in series, reusing the buffer
486
+ extra_projected_mem = x .chunkmem
485
487
486
- out = map_direct (
487
- _read_index_chunk ,
488
- x ,
489
- shape = shape ,
490
- dtype = dtype ,
491
- chunks = target_chunks ,
492
- extra_projected_mem = extra_projected_mem ,
493
- target_chunks = target_chunks ,
494
- selection = selection ,
495
- )
488
+ out = map_direct (
489
+ _read_index_chunk ,
490
+ x ,
491
+ shape = shape ,
492
+ dtype = dtype ,
493
+ chunks = target_chunks ,
494
+ extra_projected_mem = extra_projected_mem ,
495
+ target_chunks = target_chunks ,
496
+ selection = selection ,
497
+ )
496
498
497
- # merge chunks for any dims with step > 1 so they are
498
- # the same size as the input (or slightly smaller due to rounding)
499
- merged_chunks = tuple (
500
- merged_chunk_len_for_indexer (ia , c )
501
- for ia , c in zip (idx .args , x .chunksize )
502
- if not isinstance (ia , ndindex .Integer )
503
- )
504
- if chunks != merged_chunks :
505
- out = merge_chunks (out , merged_chunks )
499
+ # merge chunks for any dims with step > 1 so they are
500
+ # the same size as the input (or slightly smaller due to rounding)
501
+ merged_chunks = tuple (
502
+ merged_chunk_len_for_indexer (ia , c )
503
+ for ia , c in zip (idx .args , x .chunksize )
504
+ if not isinstance (ia , ndindex .Integer )
505
+ )
506
+ if chunks != merged_chunks :
507
+ out = merge_chunks (out , merged_chunks )
506
508
507
509
for axis in where_newaxis :
508
510
from cubed .array_api .manipulation_functions import expand_dims
0 commit comments