Skip to content

improvements to parse_dtype #3264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changes/3264.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- Expand the range of types accepted by ``parse_data_type`` to include strings and Sequences.
- Move the functionality of ``parse_data_type`` to a new function called ``parse_dtype``. This change
ensures that nomenclature is consistent across the codebase. ``parse_data_type`` remains, so this
change is not breaking.
14 changes: 7 additions & 7 deletions docs/user-guide/data_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -412,17 +412,17 @@ attempt data type resolution against *every* data type class, and if, for some r
type matches multiple Zarr data types, we treat this as an error and raise an exception.

If you have a NumPy data type and you want to get the corresponding ``ZDType`` instance, you can use
the ``parse_data_type`` function, which will use the dynamic resolution described above. ``parse_data_type``
the ``parse_dtype`` function, which will use the dynamic resolution described above. ``parse_dtype``
handles a range of input types:

- NumPy data types:

.. code-block:: python

>>> import numpy as np
>>> from zarr.dtype import parse_data_type
>>> from zarr.dtype import parse_dtype
>>> my_dtype = np.dtype('>M8[10s]')
>>> parse_data_type(my_dtype, zarr_format=2)
>>> parse_dtype(my_dtype, zarr_format=2)
DateTime64(endianness='big', scale_factor=10, unit='s')


Expand All @@ -431,7 +431,7 @@ handles a range of input types:
.. code-block:: python

>>> dtype_str = '>M8[10s]'
>>> parse_data_type(dtype_str, zarr_format=2)
>>> parse_dtype(dtype_str, zarr_format=2)
DateTime64(endianness='big', scale_factor=10, unit='s')

- ``ZDType`` instances:
Expand All @@ -440,7 +440,7 @@ handles a range of input types:

>>> from zarr.dtype import DateTime64
>>> zdt = DateTime64(endianness='big', scale_factor=10, unit='s')
>>> parse_data_type(zdt, zarr_format=2) # Use a ZDType (this is a no-op)
>>> parse_dtype(zdt, zarr_format=2) # Use a ZDType (this is a no-op)
DateTime64(endianness='big', scale_factor=10, unit='s')

- Python dictionaries (requires ``zarr_format=3``). These dictionaries must be consistent with the
Expand All @@ -449,7 +449,7 @@ handles a range of input types:
.. code-block:: python

>>> dt_dict = {"name": "numpy.datetime64", "configuration": {"unit": "s", "scale_factor": 10}}
>>> parse_data_type(dt_dict, zarr_format=3)
>>> parse_dtype(dt_dict, zarr_format=3)
DateTime64(endianness='little', scale_factor=10, unit='s')
>>> parse_data_type(dt_dict, zarr_format=3).to_json(zarr_format=3)
>>> parse_dtype(dt_dict, zarr_format=3).to_json(zarr_format=3)
{'name': 'numpy.datetime64', 'configuration': {'unit': 's', 'scale_factor': 10}}
6 changes: 3 additions & 3 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
VariableLengthUTF8,
ZDType,
ZDTypeLike,
parse_data_type,
parse_dtype,
)
from zarr.core.dtype.common import HasEndianness, HasItemSize, HasObjectCodec
from zarr.core.indexing import (
Expand Down Expand Up @@ -618,7 +618,7 @@ async def _create(
Deprecated in favor of :func:`zarr.api.asynchronous.create_array`.
"""

dtype_parsed = parse_data_type(dtype, zarr_format=zarr_format)
dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
store_path = await make_store_path(store)

shape = parse_shapelike(shape)
Expand Down Expand Up @@ -4239,7 +4239,7 @@ async def init_array(

from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation

zdtype = parse_data_type(dtype, zarr_format=zarr_format)
zdtype = parse_dtype(dtype, zarr_format=zarr_format)
shape_parsed = parse_shapelike(shape)
chunk_key_encoding_parsed = _parse_chunk_key_encoding(
chunk_key_encoding, zarr_format=zarr_format
Expand Down
68 changes: 50 additions & 18 deletions src/zarr/core/dtype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Final, TypeAlias

from zarr.core.dtype.common import (
Expand Down Expand Up @@ -94,6 +95,7 @@
"ZDType",
"data_type_registry",
"parse_data_type",
"parse_dtype",
]

data_type_registry = DataTypeRegistry()
Expand Down Expand Up @@ -188,39 +190,69 @@ def parse_data_type(
zarr_format: ZarrFormat,
) -> ZDType[TBaseDType, TBaseScalar]:
"""
Interpret the input as a ZDType instance.
Interpret the input as a ZDType.

This function wraps ``parse_dtype``. The only difference is the function name. This function may
be deprecated in a future version of Zarr Python in favor of ``parse_dtype``.

Parameters
----------
dtype_spec : ZDTypeLike
The input to be interpreted as a ZDType. This could be a ZDType, which will be returned
directly, or a JSON representation of a ZDType, or a native dtype, or a python object that
can be converted into a native dtype.
zarr_format : ZarrFormat
The Zarr format version.

Returns
-------
ZDType[TBaseDType, TBaseScalar]
The ZDType corresponding to the input.

Examples
--------
>>> parse_dtype("int32", zarr_format=2)
Int32(endianness='little')
"""
Comment on lines +193 to +216
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say bin the docstring here to avoid duplication, and just point to parse_dtype.

return parse_dtype(dtype_spec, zarr_format=zarr_format)


def parse_dtype(
dtype_spec: ZDTypeLike,
*,
zarr_format: ZarrFormat,
) -> ZDType[TBaseDType, TBaseScalar]:
"""
Interpret the input as a ZDType.

Parameters
----------
dtype_spec : ZDTypeLike
The input to be interpreted as a ZDType instance. This could be a native data type
(e.g., a NumPy data type), a Python object that can be converted into a native data type,
a ZDType instance (in which case the input is returned unchanged), or a JSON object
representation of a data type.
The input to be interpreted as a ZDType. This could be a ZDType, which will be returned
directly, or a JSON representation of a ZDType, or a native dtype, or a python object that
can be converted into a native dtype.
zarr_format : ZarrFormat
The zarr format version.
The Zarr format version.

Returns
-------
ZDType[TBaseDType, TBaseScalar]
The ZDType instance corresponding to the input.
The ZDType corresponding to the input.

Examples
--------
>>> from zarr.dtype import parse_data_type
>>> import numpy as np
>>> parse_data_type("int32", zarr_format=2)
Int32(endianness='little')
>>> parse_data_type(np.dtype('S10'), zarr_format=2)
NullTerminatedBytes(length=10)
>>> parse_data_type({"name": "numpy.datetime64", "configuration": {"unit": "s", "scale_factor": 10}}, zarr_format=3)
DateTime64(endianness='little', scale_factor=10, unit='s')
>>> parse_dtype("int32", zarr_format=2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, does this not need an import? not an issue, just a question

Int32(endianness='little')
"""
if isinstance(dtype_spec, ZDType):
return dtype_spec
# dict and zarr_format 3 means that we have a JSON object representation of the dtype
if zarr_format == 3 and isinstance(dtype_spec, Mapping):
return get_data_type_from_json(dtype_spec, zarr_format=3)
# First attempt to interpret the input as JSON
if isinstance(dtype_spec, Mapping | str | Sequence):
try:
return get_data_type_from_json(dtype_spec, zarr_format=3) # type: ignore[arg-type]
except ValueError:
# no data type matched this JSON-like input
pass
if dtype_spec in VLEN_UTF8_ALIAS:
# If the dtype request is one of the aliases for variable-length UTF-8 strings,
# return that dtype.
Expand Down
2 changes: 2 additions & 0 deletions src/zarr/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
ZDType,
data_type_registry,
parse_data_type,
parse_dtype,
)

__all__ = [
Expand Down Expand Up @@ -84,4 +85,5 @@
"data_type_registry",
"data_type_registry",
"parse_data_type",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"parse_data_type",

"parse_dtype",
]
4 changes: 2 additions & 2 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
VariableLengthBytes,
VariableLengthUTF8,
ZDType,
parse_data_type,
parse_dtype,
)
from zarr.core.dtype.common import ENDIANNESS_STR, EndiannessStr
from zarr.core.dtype.npy.common import NUMPY_ENDIANNESS_STR, endianness_from_numpy_str
Expand Down Expand Up @@ -1308,7 +1308,7 @@ async def test_v2_chunk_encoding(
filters=filters,
)
filters_expected, compressor_expected = _parse_chunk_encoding_v2(
filters=filters, compressor=compressors, dtype=parse_data_type(dtype, zarr_format=2)
filters=filters, compressor=compressors, dtype=parse_dtype(dtype, zarr_format=2)
)
assert arr.metadata.zarr_format == 2 # guard for mypy
assert arr.metadata.compressor == compressor_expected
Expand Down
66 changes: 40 additions & 26 deletions tests/test_dtype_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@
AnyDType,
Bool,
DataTypeRegistry,
DateTime64,
FixedLengthUTF32,
Int8,
Int16,
TBaseDType,
TBaseScalar,
VariableLengthUTF8,
ZDType,
data_type_registry,
get_data_type_from_json,
parse_data_type,
parse_dtype,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -174,28 +171,45 @@ def test_entrypoint_dtype(zarr_format: ZarrFormat) -> None:
data_type_registry.unregister(TestDataType._zarr_v3_name)


@pytest.mark.parametrize(
("dtype_params", "expected", "zarr_format"),
[
("str", VariableLengthUTF8(), 2),
("str", VariableLengthUTF8(), 3),
("int8", Int8(), 3),
(Int8(), Int8(), 3),
(">i2", Int16(endianness="big"), 2),
("datetime64[10s]", DateTime64(unit="s", scale_factor=10), 2),
(
{"name": "numpy.datetime64", "configuration": {"unit": "s", "scale_factor": 10}},
DateTime64(unit="s", scale_factor=10),
3,
),
],
)
def test_parse_data_type(
dtype_params: Any, expected: ZDType[Any, Any], zarr_format: ZarrFormat
) -> None:
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@pytest.mark.parametrize("data_type", zdtype_examples, ids=str)
def test_parse_data_type(data_type: ZDType[Any, Any], zarr_format: ZarrFormat) -> None:
"""
Test that parse_data_type accepts alternative representations of ZDType instances, and resolves
Test that parse_dtype accepts alternative representations of ZDType instances, and resolves
those inputs to the expected ZDType instance.
"""
observed = parse_data_type(dtype_params, zarr_format=zarr_format)
assert observed == expected
dtype_spec: Any
if zarr_format == 2:
dtype_spec = data_type.to_json(zarr_format=zarr_format)["name"]
else:
dtype_spec = data_type.to_json(zarr_format=zarr_format)
if dtype_spec == "|O":
msg = "Zarr data type resolution from object failed."
with pytest.raises(ValueError, match=msg):
parse_dtype(dtype_spec, zarr_format=zarr_format)
else:
observed = parse_dtype(dtype_spec, zarr_format=zarr_format)
assert observed == data_type


@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@pytest.mark.parametrize("data_type", zdtype_examples, ids=str)
def test_parse_data_type_funcs(data_type: ZDType[Any, Any], zarr_format: ZarrFormat) -> None:
    """
    Check that ``parse_data_type`` and ``parse_dtype`` agree with each other.

    For every example data type, both functions must return equal results for the same
    specification, and both must raise the same ``ValueError`` for the ``"|O"``
    specification.
    """
    serialized = data_type.to_json(zarr_format=zarr_format)
    # For Zarr format 2 only the "name" entry of the JSON form is used as the
    # specification; otherwise the whole JSON form is passed on unchanged.
    dtype_spec: Any = serialized["name"] if zarr_format == 2 else serialized

    if dtype_spec != "|O":
        assert parse_dtype(dtype_spec, zarr_format=zarr_format) == parse_data_type(
            dtype_spec, zarr_format=zarr_format
        )
        return

    # The "|O" specification is expected to be rejected by both entry points.
    msg = "Zarr data type resolution from object failed."
    for parser in (parse_dtype, parse_data_type):
        with pytest.raises(ValueError, match=msg):
            parser(dtype_spec, zarr_format=zarr_format)
4 changes: 2 additions & 2 deletions tests/test_metadata/test_consolidated.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
open_consolidated,
)
from zarr.core.buffer import cpu, default_buffer_prototype
from zarr.core.dtype import parse_data_type
from zarr.core.dtype import parse_dtype
from zarr.core.group import ConsolidatedMetadata, GroupMetadata
from zarr.core.metadata import ArrayV3Metadata
from zarr.core.metadata.v2 import ArrayV2Metadata
Expand Down Expand Up @@ -504,7 +504,7 @@ async def test_consolidated_metadata_backwards_compatibility(
async def test_consolidated_metadata_v2(self):
store = zarr.storage.MemoryStore()
g = await AsyncGroup.from_store(store, attributes={"key": "root"}, zarr_format=2)
dtype = parse_data_type("uint8", zarr_format=2)
dtype = parse_dtype("uint8", zarr_format=2)
await g.create_array(name="a", shape=(1,), attributes={"key": "a"}, dtype=dtype)
g1 = await g.create_group(name="g1", attributes={"key": "g1"})
await g1.create_group(name="g2", attributes={"key": "g2"})
Expand Down
Loading