Skip to content

Commit 8ca25f1

Browse files
tomwhitethodson-usgs
authored andcommitted
Fix bug where newaxis with full slices doesn't add new axes (#559)
1 parent d6c2bb0 commit 8ca25f1

File tree

2 files changed

+39
-32
lines changed

2 files changed

+39
-32
lines changed

cubed/core/ops.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -466,43 +466,45 @@ def merged_chunk_len_for_indexer(ia, c):
466466
return (c // ia.step) * ia.step
467467

468468
shape = idx.newshape(x.shape)
469+
469470
if shape == x.shape:
470-
# no op case
471-
return x
472-
dtype = x.dtype
473-
chunks = tuple(
474-
chunk_len_for_indexer(ia, c)
475-
for ia, c in zip(idx.args, x.chunksize)
476-
if not isinstance(ia, ndindex.Integer)
477-
)
471+
# no op case (except possibly newaxis applied below)
472+
out = x
473+
else:
474+
dtype = x.dtype
475+
chunks = tuple(
476+
chunk_len_for_indexer(ia, c)
477+
for ia, c in zip(idx.args, x.chunksize)
478+
if not isinstance(ia, ndindex.Integer)
479+
)
478480

479-
target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
481+
target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
480482

481-
# memory allocated by reading one chunk from input array
482-
# note that although the output chunk will overlap multiple input chunks, zarr will
483-
# read the chunks in series, reusing the buffer
484-
extra_projected_mem = x.chunkmem
483+
# memory allocated by reading one chunk from input array
484+
# note that although the output chunk will overlap multiple input chunks, zarr will
485+
# read the chunks in series, reusing the buffer
486+
extra_projected_mem = x.chunkmem
485487

486-
out = map_direct(
487-
_read_index_chunk,
488-
x,
489-
shape=shape,
490-
dtype=dtype,
491-
chunks=target_chunks,
492-
extra_projected_mem=extra_projected_mem,
493-
target_chunks=target_chunks,
494-
selection=selection,
495-
)
488+
out = map_direct(
489+
_read_index_chunk,
490+
x,
491+
shape=shape,
492+
dtype=dtype,
493+
chunks=target_chunks,
494+
extra_projected_mem=extra_projected_mem,
495+
target_chunks=target_chunks,
496+
selection=selection,
497+
)
496498

497-
# merge chunks for any dims with step > 1 so they are
498-
# the same size as the input (or slightly smaller due to rounding)
499-
merged_chunks = tuple(
500-
merged_chunk_len_for_indexer(ia, c)
501-
for ia, c in zip(idx.args, x.chunksize)
502-
if not isinstance(ia, ndindex.Integer)
503-
)
504-
if chunks != merged_chunks:
505-
out = merge_chunks(out, merged_chunks)
499+
# merge chunks for any dims with step > 1 so they are
500+
# the same size as the input (or slightly smaller due to rounding)
501+
merged_chunks = tuple(
502+
merged_chunk_len_for_indexer(ia, c)
503+
for ia, c in zip(idx.args, x.chunksize)
504+
if not isinstance(ia, ndindex.Integer)
505+
)
506+
if chunks != merged_chunks:
507+
out = merge_chunks(out, merged_chunks)
506508

507509
for axis in where_newaxis:
508510
from cubed.array_api.manipulation_functions import expand_dims

cubed/tests/test_indexing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def spec(tmp_path):
2020
[6, 7, 2, 9, 10],
2121
([6, 7, 2, 9, 10], xp.newaxis),
2222
(xp.newaxis, [6, 7, 2, 9, 10]),
23+
(slice(None), xp.newaxis),
24+
(xp.newaxis, slice(None)),
2325
],
2426
)
2527
def test_int_array_index_1d(spec, ind):
@@ -36,6 +38,9 @@ def test_int_array_index_1d(spec, ind):
3638
(xp.newaxis, slice(None), [2, 1]),
3739
(slice(None), xp.newaxis, [2, 1]),
3840
(slice(None), [2, 1], xp.newaxis),
41+
(xp.newaxis, slice(None), slice(None)),
42+
(slice(None), xp.newaxis, slice(None)),
43+
(slice(None), slice(None), xp.newaxis),
3944
],
4045
)
4146
def test_int_array_index_2d(spec, ind):

0 commit comments

Comments
 (0)