Skip to content

Commit 0f0563b

Browse files
dcherian and d-v-b authored
Update stateful/property tests. (#3161)
* Update stateful/property tests. Add actions to 1. overwrite data with oindex 2. read and compare a full array * Reduce frequency of clear --------- Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
1 parent c1ce2fd commit 0f0563b

File tree

3 files changed

+101
-42
lines changed

3 files changed

+101
-42
lines changed

src/zarr/testing/stateful.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import builtins
2-
from typing import Any
2+
import functools
3+
from collections.abc import Callable
4+
from typing import Any, TypeVar, cast
35

46
import hypothesis.extra.numpy as npst
57
import hypothesis.strategies as st
@@ -24,15 +26,43 @@
2426
from zarr.testing.strategies import (
2527
basic_indices,
2628
chunk_paths,
29+
dimension_names,
2730
key_ranges,
2831
node_names,
2932
np_array_and_chunks,
30-
numpy_arrays,
33+
orthogonal_indices,
3134
)
3235
from zarr.testing.strategies import keys as zarr_keys
3336

3437
MAX_BINARY_SIZE = 100
3538

39+
F = TypeVar("F", bound=Callable[..., Any])
40+
41+
42+
def with_frequency(frequency: float) -> Callable[[F], F]:
43+
"""This needs to be deterministic for hypothesis replaying"""
44+
45+
def decorator(func: F) -> F:
46+
counter_attr = f"__{func.__name__}_counter"
47+
48+
@functools.wraps(func)
49+
def wrapper(*args: Any, **kwargs: Any) -> Any:
50+
return func(*args, **kwargs)
51+
52+
@precondition
53+
def frequency_check(f: Any) -> Any:
54+
if not hasattr(f, counter_attr):
55+
setattr(f, counter_attr, 0)
56+
57+
current_count = getattr(f, counter_attr) + 1
58+
setattr(f, counter_attr, current_count)
59+
60+
return (current_count * frequency) % 1.0 >= (1.0 - frequency)
61+
62+
return cast(F, frequency_check(wrapper))
63+
64+
return decorator
65+
3666

3767
def split_prefix_name(path: str) -> tuple[str, str]:
3868
split = path.rsplit("/", maxsplit=1)
@@ -90,11 +120,7 @@ def add_group(self, name: str, data: DataObject) -> None:
90120
zarr.group(store=self.store, path=path)
91121
zarr.group(store=self.model, path=path)
92122

93-
@rule(
94-
data=st.data(),
95-
name=node_names,
96-
array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))),
97-
)
123+
@rule(data=st.data(), name=node_names, array_and_chunks=np_array_and_chunks())
98124
def add_array(
99125
self,
100126
data: DataObject,
@@ -122,12 +148,17 @@ def add_array(
122148
path=path,
123149
store=store,
124150
fill_value=fill_value,
151+
zarr_format=3,
152+
dimension_names=data.draw(
153+
dimension_names(ndim=array.ndim), label="dimension names"
154+
),
125155
# Chose bytes codec to avoid wasting time compressing the data being written
126156
codecs=[BytesCodec()],
127157
)
128158
self.all_arrays.add(path)
129159

130160
@rule()
161+
@with_frequency(0.25)
131162
def clear(self) -> None:
132163
note("clearing")
133164
import zarr
@@ -192,6 +223,14 @@ def delete_chunk(self, data: DataObject) -> None:
192223
self._sync(self.model.delete(path))
193224
self._sync(self.store.delete(path))
194225

226+
@precondition(lambda self: bool(self.all_arrays))
227+
@rule(data=st.data())
228+
def check_array(self, data: DataObject) -> None:
229+
path = data.draw(st.sampled_from(sorted(self.all_arrays)))
230+
actual = zarr.open_array(self.store, path=path)[:]
231+
expected = zarr.open_array(self.model, path=path)[:]
232+
np.testing.assert_equal(actual, expected)
233+
195234
@precondition(lambda self: bool(self.all_arrays))
196235
@rule(data=st.data())
197236
def overwrite_array_basic_indexing(self, data: DataObject) -> None:
@@ -206,6 +245,20 @@ def overwrite_array_basic_indexing(self, data: DataObject) -> None:
206245
model_array[slicer] = new_data
207246
store_array[slicer] = new_data
208247

248+
@precondition(lambda self: bool(self.all_arrays))
249+
@rule(data=st.data())
250+
def overwrite_array_orthogonal_indexing(self, data: DataObject) -> None:
251+
array = data.draw(st.sampled_from(sorted(self.all_arrays)))
252+
model_array = zarr.open_array(path=array, store=self.model)
253+
store_array = zarr.open_array(path=array, store=self.store)
254+
indexer, _ = data.draw(orthogonal_indices(shape=model_array.shape))
255+
note(f"overwriting array orthogonal {indexer=}")
256+
new_data = data.draw(
257+
npst.arrays(shape=model_array.oindex[indexer].shape, dtype=model_array.dtype) # type: ignore[union-attr]
258+
)
259+
model_array.oindex[indexer] = new_data
260+
store_array.oindex[indexer] = new_data
261+
209262
@precondition(lambda self: bool(self.all_arrays))
210263
@rule(data=st.data())
211264
def resize_array(self, data: DataObject) -> None:

src/zarr/testing/strategies.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def paths(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> str:
4343
return draw(st.just("/") | keys(max_num_nodes=max_num_nodes))
4444

4545

46-
def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
46+
def dtypes() -> st.SearchStrategy[np.dtype[Any]]:
4747
return (
4848
npst.boolean_dtypes()
4949
| npst.integer_dtypes(endianness="=")
@@ -57,18 +57,12 @@ def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
5757
)
5858

5959

60+
def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
61+
return dtypes()
62+
63+
6064
def v2_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
61-
return (
62-
npst.boolean_dtypes()
63-
| npst.integer_dtypes(endianness="=")
64-
| npst.unsigned_integer_dtypes(endianness="=")
65-
| npst.floating_dtypes(endianness="=")
66-
| npst.complex_number_dtypes(endianness="=")
67-
| npst.byte_string_dtypes(endianness="=")
68-
| npst.unicode_string_dtypes(endianness="=")
69-
| npst.datetime64_dtypes(endianness="=")
70-
| npst.timedelta64_dtypes(endianness="=")
71-
)
65+
return dtypes()
7266

7367

7468
def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]:
@@ -144,7 +138,7 @@ def array_metadata(
144138
shape = draw(array_shapes())
145139
ndim = len(shape)
146140
chunk_shape = draw(array_shapes(min_dims=ndim, max_dims=ndim))
147-
np_dtype = draw(v3_dtypes())
141+
np_dtype = draw(dtypes())
148142
dtype = get_data_type_from_native_dtype(np_dtype)
149143
fill_value = draw(npst.from_dtype(np_dtype))
150144
if zarr_format == 2:
@@ -179,14 +173,12 @@ def numpy_arrays(
179173
*,
180174
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
181175
dtype: np.dtype[Any] | None = None,
182-
zarr_formats: st.SearchStrategy[ZarrFormat] = zarr_formats,
183176
) -> npt.NDArray[Any]:
184177
"""
185178
Generate numpy arrays that can be saved in the provided Zarr format.
186179
"""
187-
zarr_format = draw(zarr_formats)
188180
if dtype is None:
189-
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
181+
dtype = draw(dtypes())
190182
if np.issubdtype(dtype, np.str_):
191183
safe_unicode_strings = safe_unicode_for_dtype(dtype)
192184
return draw(npst.arrays(dtype=dtype, shape=shapes, elements=safe_unicode_strings))
@@ -255,17 +247,24 @@ def arrays(
255247
attrs: st.SearchStrategy = attrs,
256248
zarr_formats: st.SearchStrategy = zarr_formats,
257249
) -> Array:
258-
store = draw(stores)
259-
path = draw(paths)
260-
name = draw(array_names)
261-
attributes = draw(attrs)
262-
zarr_format = draw(zarr_formats)
250+
store = draw(stores, label="store")
251+
path = draw(paths, label="array parent")
252+
name = draw(array_names, label="array name")
253+
attributes = draw(attrs, label="attributes")
254+
zarr_format = draw(zarr_formats, label="zarr format")
263255
if arrays is None:
264-
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
265-
nparray = draw(arrays)
266-
chunk_shape = draw(chunk_shapes(shape=nparray.shape))
256+
arrays = numpy_arrays(shapes=shapes)
257+
nparray = draw(arrays, label="array data")
258+
chunk_shape = draw(chunk_shapes(shape=nparray.shape), label="chunk shape")
259+
extra_kwargs = {}
267260
if zarr_format == 3 and all(c > 0 for c in chunk_shape):
268-
shard_shape = draw(st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape))
261+
shard_shape = draw(
262+
st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape),
263+
label="shard shape",
264+
)
265+
extra_kwargs["dimension_names"] = draw(
266+
dimension_names(ndim=nparray.ndim), label="dimension names"
267+
)
269268
else:
270269
shard_shape = None
271270
# test that None works too.
@@ -286,6 +285,7 @@ def arrays(
286285
attributes=attributes,
287286
# compressor=compressor, # FIXME
288287
fill_value=fill_value,
288+
**extra_kwargs,
289289
)
290290

291291
assert isinstance(a, Array)
@@ -385,13 +385,19 @@ def orthogonal_indices(
385385
npindexer = []
386386
ndim = len(shape)
387387
for axis, size in enumerate(shape):
388-
val = draw(
389-
npst.integer_array_indices(
388+
if size != 0:
389+
strategy = npst.integer_array_indices(
390390
shape=(size,), result_shape=npst.array_shapes(min_side=1, max_side=size, max_dims=1)
391-
)
392-
| basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
393-
.map(lambda x: (x,) if not isinstance(x, tuple) else x) # bare ints, slices
394-
.filter(bool) # skip empty tuple
391+
) | basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
392+
else:
393+
strategy = basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
394+
395+
val = draw(
396+
strategy
397+
# bare ints, slices
398+
.map(lambda x: (x,) if not isinstance(x, tuple) else x)
399+
# skip empty tuple
400+
.filter(bool)
395401
)
396402
(idxr,) = val
397403
if isinstance(idxr, int):

tests/test_properties.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ def deep_equal(a: Any, b: Any) -> bool:
7676

7777

7878
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
79-
@given(data=st.data(), zarr_format=zarr_formats)
80-
def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None:
81-
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
82-
zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)))
79+
@given(data=st.data())
80+
def test_array_roundtrip(data: st.DataObject) -> None:
81+
nparray = data.draw(numpy_arrays())
82+
zarray = data.draw(arrays(arrays=st.just(nparray)))
8383
assert_array_equal(nparray, zarray[:])
8484

8585

0 commit comments

Comments
 (0)