Skip to content

Commit 79d6969

Browse files
authored
Fix bug with expand dims of a scalar array (#103)
* simple function to create manifestarray in tests * test to expose bug with broadcast_to for scalars * fix usage of outdated attribute name * fix bug by special-casing broadcasting of scalar arrays
1 parent 0f37222 commit 79d6969

File tree

4 files changed

+96
-7
lines changed

4 files changed

+96
-7
lines changed

virtualizarr/manifests/array_api.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def concatenate(
112112

113113
# Ensure we handle axis being passed as a negative integer
114114
first_arr = arrays[0]
115-
axis = axis % first_arr.ndim
115+
if axis < 0:
116+
axis = axis % first_arr.ndim
116117

117118
arr_shapes = [arr.shape for arr in arrays]
118119
_check_same_shapes_except_on_concat_axis(arr_shapes, axis)
@@ -154,6 +155,7 @@ def _check_same_ndims(ndims: list[int]) -> None:
154155

155156
def _check_same_shapes_except_on_concat_axis(shapes: list[tuple[int, ...]], axis: int):
156157
"""Check that shapes are compatible for concatenation"""
158+
157159
shapes_without_concat_axis = [
158160
_remove_element_at_position(shape, axis) for shape in shapes
159161
]
@@ -198,7 +200,8 @@ def stack(
198200

199201
# Ensure we handle axis being passed as a negative integer
200202
first_arr = arrays[0]
201-
axis = axis % first_arr.ndim
203+
if axis < 0:
204+
axis = axis % first_arr.ndim
202205

203206
# find what new array shape must be
204207
length_along_new_stacked_axis = len(arrays)
@@ -267,8 +270,13 @@ def broadcast_to(x: "ManifestArray", /, shape: Tuple[int, ...]) -> "ManifestArra
267270
if d == d_requested:
268271
pass
269272
elif d is None:
270-
# stack same array upon itself d_requested number of times, which inserts a new axis at axis=0
271-
result = stack([result] * d_requested, axis=0)
273+
if result.shape == ():
274+
# scalars are a special case because their manifests already have a chunk key with one dimension
275+
# see https://github.com/TomNicholas/VirtualiZarr/issues/100#issuecomment-2097058282
276+
result = _broadcast_scalar(result, new_axis_length=d_requested)
277+
else:
278+
# stack same array upon itself d_requested number of times, which inserts a new axis at axis=0
279+
result = stack([result] * d_requested, axis=0)
272280
elif d == 1:
273281
# concatenate same array upon itself d_requested number of times along existing axis
274282
result = concatenate([result] * d_requested, axis=axis)
@@ -280,6 +288,41 @@ def broadcast_to(x: "ManifestArray", /, shape: Tuple[int, ...]) -> "ManifestArra
280288
return result
281289

282290

291+
def _broadcast_scalar(x: "ManifestArray", new_axis_length: int) -> "ManifestArray":
292+
"""
293+
Add an axis to a scalar ManifestArray, but without adding a new axis to the keys of the chunk manifest.
294+
295+
This is not the same as concatenation, because there is no existing axis along which to concatenate.
296+
It's also not the same as stacking, because we don't want to insert a new axis into the chunk keys.
297+
298+
Scalars are a special case because their manifests still have a chunk key with one dimension.
299+
See https://github.com/TomNicholas/VirtualiZarr/issues/100#issuecomment-2097058282
300+
"""
301+
302+
from .array import ManifestArray
303+
304+
new_shape = (new_axis_length,)
305+
new_chunks = (new_axis_length,)
306+
307+
concatenated_manifest = concat_manifests(
308+
[x.manifest] * new_axis_length,
309+
axis=0,
310+
)
311+
312+
new_zarray = ZArray(
313+
chunks=new_chunks,
314+
compressor=x.zarray.compressor,
315+
dtype=x.dtype,
316+
fill_value=x.zarray.fill_value,
317+
filters=x.zarray.filters,
318+
shape=new_shape,
319+
order=x.zarray.order,
320+
zarr_format=x.zarray.zarr_format,
321+
)
322+
323+
return ManifestArray(chunkmanifest=concatenated_manifest, zarray=new_zarray)
324+
325+
283326
# TODO broadcast_arrays, squeeze, permute_dims
284327

285328

virtualizarr/manifests/manifest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,13 @@ def __repr__(self) -> str:
9999
return f"ChunkManifest<shape={self.shape_chunk_grid}>"
100100

101101
def __getitem__(self, key: ChunkKey) -> ChunkEntry:
102-
return self.chunks[key]
102+
return self.entries[key]
103103

104104
def __iter__(self) -> Iterator[ChunkKey]:
105-
return iter(self.chunks.keys())
105+
return iter(self.entries.keys())
106106

107107
def __len__(self) -> int:
108-
return len(self.chunks)
108+
return len(self.entries)
109109

110110
def dict(self) -> dict[str, dict[str, Union[str, int]]]:
111111
"""Converts the entire manifest to a nested dictionary."""

virtualizarr/tests/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import numpy as np
2+
3+
from virtualizarr.manifests import ChunkEntry, ChunkManifest, ManifestArray
4+
from virtualizarr.zarr import ZArray
5+
6+
7+
def create_manifestarray(
8+
shape: tuple[int, ...], chunks: tuple[int, ...]
9+
) -> ManifestArray:
10+
"""
11+
Create an example ManifestArray with sensible defaults.
12+
"""
13+
14+
zarray = ZArray(
15+
chunks=chunks,
16+
compressor="zlib",
17+
dtype=np.dtype("float32"),
18+
fill_value=0.0, # TODO change this to NaN?
19+
filters=None,
20+
order="C",
21+
shape=shape,
22+
zarr_format=2,
23+
)
24+
25+
if shape != ():
26+
raise NotImplementedError(
27+
"Only generation of array representing a single scalar currently supported"
28+
)
29+
30+
# TODO generalize this
31+
chunkmanifest = ChunkManifest(
32+
entries={"0": ChunkEntry(path="scalar.nc", offset=6144, length=48)}
33+
)
34+
35+
return ManifestArray(chunkmanifest=chunkmanifest, zarray=zarray)

virtualizarr/tests/test_manifests/test_array.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from virtualizarr.manifests import ChunkManifest, ManifestArray
5+
from virtualizarr.tests import create_manifestarray
56
from virtualizarr.zarr import ZArray
67

78

@@ -122,6 +123,16 @@ def test_not_equal_chunk_entries(self):
122123
def test_partly_equals(self): ...
123124

124125

126+
class TestBroadcast:
127+
def test_broadcast_scalar(self):
128+
# regression test
129+
marr = create_manifestarray(shape=(), chunks=())
130+
expanded = np.broadcast_to(marr, shape=(1,))
131+
assert expanded.shape == (1,)
132+
assert expanded.chunks == (1,)
133+
assert expanded.manifest == marr.manifest
134+
135+
125136
# TODO we really need some kind of fixtures to generate useful example data
126137
# The hard part is having an alternative way to get to the expected result of concatenation
127138
class TestConcat:

0 commit comments

Comments
 (0)