Skip to content

Commit 9bc52fc

Browse files
authored
Make create_array signatures consistent (#2819)
* add signature tests for async / sync api, and fix mismatched signatures * test for consistent signatures, and make array default fill value consistently 0 * test for async group / group methods * release notes * default fill value is None * expand changelog * Update asynchronous.py * fix diverged signatures * use fill_value = 0 in metadata consolidation test * make signature tests more verbose * make signatures consistent
1 parent 378d5af commit 9bc52fc

File tree

6 files changed

+157
-17
lines changed

6 files changed

+157
-17
lines changed

changes/2819.chore.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Ensure that invocations of ``create_array`` use consistent keyword arguments, with consistent defaults.
2+
Specifically, ``zarr.api.synchronous.create_array`` now takes a ``write_data`` keyword argument; the
3+
``create_array`` method on ``zarr.Group`` takes ``data`` and ``write_data`` keyword arguments. The ``fill_value``
4+
keyword argument of the various invocations of ``create_array`` has been consistently set to ``None``, where previously it was either ``None`` or ``0``.

src/zarr/api/asynchronous.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from zarr.abc.store import Store
1313
from zarr.core.array import (
14+
DEFAULT_FILL_VALUE,
1415
Array,
1516
AsyncArray,
1617
CompressorLike,
@@ -860,10 +861,10 @@ async def open_group(
860861
async def create(
861862
shape: ChunkCoords | int,
862863
*, # Note: this is a change from v2
863-
chunks: ChunkCoords | int | None = None, # TODO: v2 allowed chunks=True
864+
chunks: ChunkCoords | int | bool | None = None,
864865
dtype: ZDTypeLike | None = None,
865866
compressor: CompressorLike = "auto",
866-
fill_value: Any | None = 0, # TODO: need type
867+
fill_value: Any | None = DEFAULT_FILL_VALUE,
867868
order: MemoryOrder | None = None,
868869
store: str | StoreLike | None = None,
869870
synchronizer: Any | None = None,

src/zarr/core/group.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from zarr.abc.store import Store, set_or_delete
2121
from zarr.core._info import GroupInfo
2222
from zarr.core.array import (
23+
DEFAULT_FILL_VALUE,
2324
Array,
2425
AsyncArray,
2526
CompressorLike,
@@ -71,6 +72,7 @@
7172
from zarr.core.buffer import Buffer, BufferPrototype
7273
from zarr.core.chunk_key_encodings import ChunkKeyEncodingLike
7374
from zarr.core.common import MemoryOrder
75+
from zarr.core.dtype import ZDTypeLike
7476

7577
logger = logging.getLogger("zarr.group")
7678

@@ -999,22 +1001,24 @@ async def create_array(
9991001
self,
10001002
name: str,
10011003
*,
1002-
shape: ShapeLike,
1003-
dtype: npt.DTypeLike,
1004+
shape: ShapeLike | None = None,
1005+
dtype: ZDTypeLike | None = None,
1006+
data: np.ndarray[Any, np.dtype[Any]] | None = None,
10041007
chunks: ChunkCoords | Literal["auto"] = "auto",
10051008
shards: ShardsLike | None = None,
10061009
filters: FiltersLike = "auto",
10071010
compressors: CompressorsLike = "auto",
10081011
compressor: CompressorLike = "auto",
10091012
serializer: SerializerLike = "auto",
1010-
fill_value: Any | None = 0,
1013+
fill_value: Any | None = DEFAULT_FILL_VALUE,
10111014
order: MemoryOrder | None = None,
10121015
attributes: dict[str, JSON] | None = None,
10131016
chunk_key_encoding: ChunkKeyEncodingLike | None = None,
10141017
dimension_names: DimensionNames = None,
10151018
storage_options: dict[str, Any] | None = None,
10161019
overwrite: bool = False,
1017-
config: ArrayConfig | ArrayConfigLike | None = None,
1020+
config: ArrayConfigLike | None = None,
1021+
write_data: bool = True,
10181022
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
10191023
"""Create an array within this group.
10201024
@@ -1102,6 +1106,11 @@ async def create_array(
11021106
Whether to overwrite an array with the same name in the store, if one exists.
11031107
config : ArrayConfig or ArrayConfigLike, optional
11041108
Runtime configuration for the array.
1109+
write_data : bool
1110+
If a pre-existing array-like object was provided to this function via the ``data`` parameter
1111+
then ``write_data`` determines whether the values in that array-like object should be
1112+
written to the Zarr array created by this function. If ``write_data`` is ``False``, then the
1113+
array will be left empty.
11051114
11061115
Returns
11071116
-------
@@ -1116,6 +1125,7 @@ async def create_array(
11161125
name=name,
11171126
shape=shape,
11181127
dtype=dtype,
1128+
data=data,
11191129
chunks=chunks,
11201130
shards=shards,
11211131
filters=filters,
@@ -1130,6 +1140,7 @@ async def create_array(
11301140
storage_options=storage_options,
11311141
overwrite=overwrite,
11321142
config=config,
1143+
write_data=write_data,
11331144
)
11341145

11351146
@deprecated("Use AsyncGroup.create_array instead.")
@@ -2411,22 +2422,24 @@ def create_array(
24112422
self,
24122423
name: str,
24132424
*,
2414-
shape: ShapeLike,
2415-
dtype: npt.DTypeLike,
2425+
shape: ShapeLike | None = None,
2426+
dtype: ZDTypeLike | None = None,
2427+
data: np.ndarray[Any, np.dtype[Any]] | None = None,
24162428
chunks: ChunkCoords | Literal["auto"] = "auto",
24172429
shards: ShardsLike | None = None,
24182430
filters: FiltersLike = "auto",
24192431
compressors: CompressorsLike = "auto",
24202432
compressor: CompressorLike = "auto",
24212433
serializer: SerializerLike = "auto",
2422-
fill_value: Any | None = 0,
2423-
order: MemoryOrder | None = "C",
2434+
fill_value: Any | None = DEFAULT_FILL_VALUE,
2435+
order: MemoryOrder | None = None,
24242436
attributes: dict[str, JSON] | None = None,
24252437
chunk_key_encoding: ChunkKeyEncodingLike | None = None,
24262438
dimension_names: DimensionNames = None,
24272439
storage_options: dict[str, Any] | None = None,
24282440
overwrite: bool = False,
2429-
config: ArrayConfig | ArrayConfigLike | None = None,
2441+
config: ArrayConfigLike | None = None,
2442+
write_data: bool = True,
24302443
) -> Array:
24312444
"""Create an array within this group.
24322445
@@ -2437,10 +2450,13 @@ def create_array(
24372450
name : str
24382451
The name of the array relative to the group. If ``path`` is ``None``, the array will be located
24392452
at the root of the store.
2440-
shape : ChunkCoords
2441-
Shape of the array.
2442-
dtype : npt.DTypeLike
2443-
Data type of the array.
2453+
shape : ChunkCoords, optional
2454+
Shape of the array. Can be ``None`` if ``data`` is provided.
2455+
dtype : npt.DTypeLike | None
2456+
Data type of the array. Can be ``None`` if ``data`` is provided.
2457+
data : Array-like data to use for initializing the array. If this parameter is provided, the
2458+
``shape`` and ``dtype`` parameters must be identical to ``data.shape`` and ``data.dtype``,
2459+
or ``None``.
24442460
chunks : ChunkCoords, optional
24452461
Chunk shape of the array.
24462462
If not specified, default are guessed based on the shape and dtype.
@@ -2514,6 +2530,11 @@ def create_array(
25142530
Whether to overwrite an array with the same name in the store, if one exists.
25152531
config : ArrayConfig or ArrayConfigLike, optional
25162532
Runtime configuration for the array.
2533+
write_data : bool
2534+
If a pre-existing array-like object was provided to this function via the ``data`` parameter
2535+
then ``write_data`` determines whether the values in that array-like object should be
2536+
written to the Zarr array created by this function. If ``write_data`` is ``False``, then the
2537+
array will be left empty.
25172538
25182539
Returns
25192540
-------
@@ -2528,6 +2549,7 @@ def create_array(
25282549
name=name,
25292550
shape=shape,
25302551
dtype=dtype,
2552+
data=data,
25312553
chunks=chunks,
25322554
shards=shards,
25332555
fill_value=fill_value,
@@ -2541,6 +2563,7 @@ def create_array(
25412563
overwrite=overwrite,
25422564
storage_options=storage_options,
25432565
config=config,
2566+
write_data=write_data,
25442567
)
25452568
)
25462569
)
@@ -2813,7 +2836,7 @@ def array(
28132836
compressors: CompressorsLike = "auto",
28142837
compressor: CompressorLike = None,
28152838
serializer: SerializerLike = "auto",
2816-
fill_value: Any | None = 0,
2839+
fill_value: Any | None = DEFAULT_FILL_VALUE,
28172840
order: MemoryOrder | None = "C",
28182841
attributes: dict[str, JSON] | None = None,
28192842
chunk_key_encoding: ChunkKeyEncodingLike | None = None,

tests/test_api.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import inspect
4+
import pathlib
35
import re
46
from typing import TYPE_CHECKING
57

@@ -8,6 +10,7 @@
810

911
if TYPE_CHECKING:
1012
import pathlib
13+
from collections.abc import Callable
1114

1215
from zarr.abc.store import Store
1316
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
@@ -1216,6 +1219,43 @@ def test_open_array_with_mode_r_plus(store: Store, zarr_format: ZarrFormat) -> N
12161219
z2[:] = 3
12171220

12181221

1222+
@pytest.mark.parametrize(
    ("a_func", "b_func"),
    [
        (zarr.api.asynchronous.create_array, zarr.api.synchronous.create_array),
        (zarr.api.asynchronous.save, zarr.api.synchronous.save),
        (zarr.api.asynchronous.save_array, zarr.api.synchronous.save_array),
        (zarr.api.asynchronous.save_group, zarr.api.synchronous.save_group),
        (zarr.api.asynchronous.open_group, zarr.api.synchronous.open_group),
        (zarr.api.asynchronous.create, zarr.api.synchronous.create),
    ],
)
def test_consistent_signatures(
    a_func: Callable[[object], object], b_func: Callable[[object], object]
) -> None:
    """
    Ensure that pairs of functions have the same signature.

    Parameters
    ----------
    a_func : Callable
        The reference ("base") function.
    b_func : Callable
        The function whose signature is compared against ``a_func``.

    Notes
    -----
    Mismatches are collected into three buckets before asserting, so that a failure
    reports every offending parameter instead of only the first one:

    - ``missing_from_test``: parameters of ``a_func`` that ``b_func`` lacks.
    - ``missing_from_base``: parameters of ``b_func`` that ``a_func`` lacks.
    - ``wrong_type``: parameters present in both whose ``inspect.Parameter`` objects differ
      (kind, default, or annotation).
    """
    base_sig = inspect.signature(a_func)
    test_sig = inspect.signature(b_func)
    wrong: dict[str, list[object]] = {
        "missing_from_test": [],
        "missing_from_base": [],
        "wrong_type": [],
    }
    for key, value in base_sig.parameters.items():
        if key not in test_sig.parameters:
            wrong["missing_from_test"].append((key, value))
    for key, value in test_sig.parameters.items():
        if key not in base_sig.parameters:
            wrong["missing_from_base"].append((key, value))
        # ``elif`` is required: looking up a key that is absent from the base signature
        # would raise ``KeyError`` and hide the collected report behind a traceback.
        elif base_sig.parameters[key] != value:
            wrong["wrong_type"].append({key: {"test": value, "base": base_sig.parameters[key]}})
    assert wrong["missing_from_base"] == []
    assert wrong["missing_from_test"] == []
    assert wrong["wrong_type"] == []
1257+
1258+
12191259
def test_api_exports() -> None:
12201260
"""
12211261
Test that the sync API and the async API export the same objects

tests/test_array.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,43 @@ def test_auto_partition_auto_shards(
970970
assert auto_shards == expected_shards
971971

972972

973+
def test_chunks_and_shards() -> None:
    """
    Check the ``chunks`` and ``shards`` attributes reported by newly created arrays.

    Three arrays are created under one in-memory store: a default-format array without
    shards, a default-format array with shards, and a ``zarr_format=2`` array. Each must
    report the requested chunk shape; ``shards`` must be ``None`` unless it was requested.
    """
    root = StorePath(MemoryStore())
    array_shape = (100, 100)
    chunk_shape = (5, 5)
    shard_shape = (10, 10)

    # No ``shards`` argument: the array should report no sharding.
    unsharded = zarr.create_array(
        store=root / "v3", shape=array_shape, chunks=chunk_shape, dtype="i4"
    )
    assert unsharded.chunks == chunk_shape
    assert unsharded.shards is None

    # Explicit ``shards``: both attributes should echo the requested values.
    sharded = zarr.create_array(
        store=root / "v3_sharding",
        shape=array_shape,
        chunks=chunk_shape,
        shards=shard_shape,
        dtype="i4",
    )
    assert sharded.chunks == chunk_shape
    assert sharded.shards == shard_shape

    # Explicit ``zarr_format=2`` without ``shards``.
    legacy = zarr.create_array(
        store=root / "v2",
        shape=array_shape,
        chunks=chunk_shape,
        zarr_format=2,
        dtype="i4",
    )
    assert legacy.chunks == chunk_shape
    assert legacy.shards is None
998+
999+
1000+
@pytest.mark.parametrize("store", ["memory"], indirect=True)
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@pytest.mark.parametrize(
    ("dtype", "fill_value_expected"),
    [
        ("<U4", ""),
        ("<S4", b""),
        ("i", 0),
        ("f", 0.0),
    ],
)
def test_default_fill_value(dtype: str, fill_value_expected: object, store: Store) -> None:
    """
    Check the ``fill_value`` of an array created without an explicit ``fill_value``.

    Each parametrized ``dtype`` is paired with the fill value the created array is
    expected to report.
    """
    created = zarr.create_array(store, shape=(5,), chunks=(5,), dtype=dtype)
    observed = created.fill_value
    assert observed == fill_value_expected
1008+
1009+
9731010
@pytest.mark.parametrize("store", ["memory"], indirect=True)
9741011
class TestCreateArray:
9751012
@staticmethod
@@ -1769,6 +1806,25 @@ def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkser
17691806
assert all(np.array_equal(r, data) for r in results)
17701807

17711808

1809+
def test_create_array_method_signature() -> None:
    """
    Check that ``AsyncGroup.create_array`` accepts the parameters of ``create_array``.

    Every parameter of the ``create_array`` function, other than the ones listed as
    ignored below, must appear in the method signature as an equal
    ``inspect.Parameter`` (same name, kind, default and annotation).
    """
    function_params = inspect.signature(create_array).parameters
    method_params = inspect.signature(AsyncGroup.create_array).parameters

    # ignore keyword arguments that are either missing or have different semantics when
    # create_array is invoked as a group method
    skipped = {"zarr_format", "store", "name"}

    # TODO: make this test stronger. right now, it only checks that all the parameters in the
    # function signature are used in the method signature. we can be more strict and check that
    # the method signature uses no extra parameters.
    required = {
        (param_name, param)
        for param_name, param in function_params.items()
        if param_name not in skipped
    }
    unmatched = required - set(method_params.items())
    assert unmatched == set()
1826+
1827+
17721828
async def test_sharding_coordinate_selection() -> None:
17731829
store = MemoryStore()
17741830
g = zarr.open_group(store, mode="w")

tests/test_group.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,7 @@ def test_create_nodes_concurrency_limit(store: MemoryStore) -> None:
15311531
@pytest.mark.parametrize(
15321532
("a_func", "b_func"),
15331533
[
1534+
(zarr.core.group.AsyncGroup.create_array, zarr.core.group.Group.create_array),
15341535
(zarr.core.group.AsyncGroup.create_hierarchy, zarr.core.group.Group.create_hierarchy),
15351536
(zarr.core.group.create_hierarchy, zarr.core.sync_group.create_hierarchy),
15361537
(zarr.core.group.create_nodes, zarr.core.sync_group.create_nodes),
@@ -1546,7 +1547,22 @@ def test_consistent_signatures(
15461547
"""
15471548
base_sig = inspect.signature(a_func)
15481549
test_sig = inspect.signature(b_func)
1549-
assert test_sig.parameters == base_sig.parameters
1550+
wrong: dict[str, list[object]] = {
1551+
"missing_from_test": [],
1552+
"missing_from_base": [],
1553+
"wrong_type": [],
1554+
}
1555+
for key, value in base_sig.parameters.items():
1556+
if key not in test_sig.parameters:
1557+
wrong["missing_from_test"].append((key, value))
1558+
for key, value in test_sig.parameters.items():
1559+
if key not in base_sig.parameters:
1560+
wrong["missing_from_base"].append((key, value))
1561+
if base_sig.parameters[key] != value:
1562+
wrong["wrong_type"].append({key: {"test": value, "base": base_sig.parameters[key]}})
1563+
assert wrong["missing_from_base"] == []
1564+
assert wrong["missing_from_test"] == []
1565+
assert wrong["wrong_type"] == []
15501566

15511567

15521568
@pytest.mark.parametrize("store", ["memory"], indirect=True)

0 commit comments

Comments
 (0)