Skip to content

Commit 1b5c2ec

Browse files
committed
Fix AsyncGroup.create_dataset() dtype handling and optimize tests
1 parent 7584b96 commit 1b5c2ec

File tree

3 files changed

+30
-13
lines changed

3 files changed

+30
-13
lines changed

changes/3050.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Fixed potential error in `AsyncGroup.create_dataset()` where `dtype` argument could be missing when calling `create_array()`

src/zarr/core/group.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,8 +1155,11 @@ async def create_dataset(
11551155
# create_dataset in zarr 2.x requires shape but not dtype if data is
11561156
# provided. Allow this configuration by inferring dtype from data if
11571157
# necessary and passing it to create_array
1158-
if "dtype" not in kwargs and data is not None:
1159-
kwargs["dtype"] = data.dtype
1158+
if "dtype" not in kwargs:
1159+
if data is not None:
1160+
kwargs["dtype"] = data.dtype
1161+
else:
1162+
raise ValueError("dtype must be provided if data is None")
11601163
array = await self.create_array(name, shape=shape, **kwargs)
11611164
if data is not None:
11621165
await array.setitem(slice(None), data)
@@ -2544,12 +2547,17 @@ def require_dataset(self, name: str, *, shape: ShapeLike, **kwargs: Any) -> Arra
25442547
----------
25452548
name : str
25462549
Array name.
2547-
**kwargs :
2548-
See :func:`zarr.Group.create_dataset`.
2550+
shape : int or tuple of ints
2551+
Array shape.
2552+
dtype : str or dtype, optional
2553+
NumPy dtype.
2554+
exact : bool, optional
2555+
If True, require `dtype` to match exactly. If false, require
2556+
`dtype` can be cast from array dtype.
25492557
25502558
Returns
25512559
-------
2552-
a : Array
2560+
a : AsyncArray
25532561
"""
25542562
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))
25552563

@@ -2562,12 +2570,17 @@ def require_array(self, name: str, *, shape: ShapeLike, **kwargs: Any) -> Array:
25622570
----------
25632571
name : str
25642572
Array name.
2565-
**kwargs :
2566-
See :func:`zarr.Group.create_array`.
2573+
shape : int or tuple of ints
2574+
Array shape.
2575+
dtype : str or dtype, optional
2576+
NumPy dtype.
2577+
exact : bool, optional
2578+
If True, require `dtype` to match exactly. If false, require
2579+
`dtype` can be cast from array dtype.
25672580
25682581
Returns
25692582
-------
2570-
a : Array
2583+
a : AsyncArray
25712584
"""
25722585
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))
25732586

tests/test_properties.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import hypothesis.extra.numpy as npst
1515
import hypothesis.strategies as st
16-
from hypothesis import assume, given, settings
16+
from hypothesis import assume, given, settings, HealthCheck
1717

1818
from zarr.abc.store import Store
1919
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON
@@ -75,7 +75,7 @@ def deep_equal(a: Any, b: Any) -> bool:
7575

7676
return a == b
7777

78-
78+
@settings(deadline=None) # Increased from default 200ms to None
7979
@given(data=st.data(), zarr_format=zarr_formats)
8080
def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None:
8181
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
@@ -117,10 +117,11 @@ def test_basic_indexing(data: st.DataObject) -> None:
117117
assert_array_equal(nparray, zarray[:])
118118

119119

120+
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
120121
@given(data=st.data())
121122
def test_oindex(data: st.DataObject) -> None:
122123
# integer_array_indices can't handle 0-size dimensions.
123-
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
124+
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=3, min_side=1, max_side=8)))
124125
nparray = zarray[:]
125126

126127
zindexer, npindexer = data.draw(orthogonal_indices(shape=nparray.shape))
@@ -138,15 +139,17 @@ def test_oindex(data: st.DataObject) -> None:
138139
assert_array_equal(nparray, zarray[:])
139140

140141

142+
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
141143
@given(data=st.data())
142144
def test_vindex(data: st.DataObject) -> None:
143145
# integer_array_indices can't handle 0-size dimensions.
144-
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
146+
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=3, min_side=1, max_side=8)))
145147
nparray = zarray[:]
146148

147149
indexer = data.draw(
148150
npst.integer_array_indices(
149-
shape=nparray.shape, result_shape=npst.array_shapes(min_side=1, max_dims=None)
151+
shape=nparray.shape,
152+
result_shape=npst.array_shapes(min_side=1, max_dims=2, max_side=8)
150153
)
151154
)
152155
actual = zarray.vindex[indexer]

0 commit comments

Comments
 (0)