|
| 1 | +# /// script |
| 2 | +# requires-python = ">=3.11" |
| 3 | +# dependencies = [ |
| 4 | +# "zarr @ git+https://github.com/zarr-developers/zarr-python.git@main", |
| 5 | +# "ml_dtypes==0.5.1", |
| 6 | +# "pytest==8.4.1" |
| 7 | +# ] |
| 8 | +# /// |
| 9 | +# |
| 10 | + |
| 11 | +""" |
| 12 | +Demonstrate how to extend Zarr Python by defining a new data type |
| 13 | +""" |
| 14 | + |
| 15 | +import json |
| 16 | +import sys |
| 17 | +from pathlib import Path |
| 18 | +from typing import ClassVar, Literal, Self, TypeGuard, overload |
| 19 | + |
| 20 | +import ml_dtypes # necessary to add extra dtypes to NumPy |
| 21 | +import numpy as np |
| 22 | +import pytest |
| 23 | + |
| 24 | +import zarr |
| 25 | +from zarr.core.common import JSON, ZarrFormat |
| 26 | +from zarr.core.dtype import ZDType, data_type_registry |
| 27 | +from zarr.core.dtype.common import ( |
| 28 | + DataTypeValidationError, |
| 29 | + DTypeConfig_V2, |
| 30 | + DTypeJSON, |
| 31 | + check_dtype_spec_v2, |
| 32 | +) |
| 33 | + |
| 34 | +# This is the int2 array data type |
| 35 | +int2_dtype_cls = type(np.dtype("int2")) |
| 36 | + |
| 37 | +# This is the int2 scalar type |
| 38 | +int2_scalar_cls = ml_dtypes.int2 |
| 39 | + |
| 40 | + |
| 41 | +class Int2(ZDType[int2_dtype_cls, int2_scalar_cls]): |
| 42 | + """ |
| 43 | + This class provides a Zarr compatibility layer around the int2 data type (the ``dtype`` of a |
| 44 | + NumPy array of type int2) and the int2 scalar type (the ``dtype`` of the scalar value inside an int2 array). |
| 45 | + """ |
| 46 | + |
| 47 | + # This field is as the key for the data type in the internal data type registry, and also |
| 48 | + # as the identifier for the data type when serializaing the data type to disk for zarr v3 |
| 49 | + _zarr_v3_name: ClassVar[Literal["int2"]] = "int2" |
| 50 | + # this field will be used internally |
| 51 | + _zarr_v2_name: ClassVar[Literal["int2"]] = "int2" |
| 52 | + |
| 53 | + # we bind a class variable to the native data type class so we can create instances of it |
| 54 | + dtype_cls = int2_dtype_cls |
| 55 | + |
| 56 | + @classmethod |
| 57 | + def from_native_dtype(cls, dtype: np.dtype) -> Self: |
| 58 | + """Create an instance of this ZDType from a native dtype.""" |
| 59 | + if cls._check_native_dtype(dtype): |
| 60 | + return cls() |
| 61 | + raise DataTypeValidationError( |
| 62 | + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" |
| 63 | + ) |
| 64 | + |
| 65 | + def to_native_dtype(self: Self) -> int2_dtype_cls: |
| 66 | + """Create an int2 dtype instance from this ZDType""" |
| 67 | + return self.dtype_cls() |
| 68 | + |
| 69 | + @classmethod |
| 70 | + def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[Literal["|b1"], None]]: |
| 71 | + """ |
| 72 | + Type check for Zarr v2-flavored JSON. |
| 73 | +
|
| 74 | + This will check that the input is a dict like this: |
| 75 | + .. code-block:: json |
| 76 | +
|
| 77 | + { |
| 78 | + "name": "int2", |
| 79 | + "object_codec_id": None |
| 80 | + } |
| 81 | +
|
| 82 | + Note that this representation differs from the ``dtype`` field looks like in zarr v2 metadata. |
| 83 | + Specifically, whatever goes into the ``dtype`` field in metadata is assigned to the ``name`` field here. |
| 84 | +
|
| 85 | + See the Zarr docs for more information about the JSON encoding for data types. |
| 86 | + """ |
| 87 | + return ( |
| 88 | + check_dtype_spec_v2(data) and data["name"] == "int2" and data["object_codec_id"] is None |
| 89 | + ) |
| 90 | + |
| 91 | + @classmethod |
| 92 | + def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[Literal["int2"]]: |
| 93 | + """ |
| 94 | + Type check for Zarr V3-flavored JSON. |
| 95 | +
|
| 96 | + Checks that the input is the string "int2". |
| 97 | + """ |
| 98 | + return data == cls._zarr_v3_name |
| 99 | + |
| 100 | + @classmethod |
| 101 | + def _from_json_v2(cls, data: DTypeJSON) -> Self: |
| 102 | + """ |
| 103 | + Create an instance of this ZDType from Zarr V3-flavored JSON. |
| 104 | + """ |
| 105 | + if cls._check_json_v2(data): |
| 106 | + return cls() |
| 107 | + # This first does a type check on the input, and if that passes we create an instance of the ZDType. |
| 108 | + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v2_name!r}" |
| 109 | + raise DataTypeValidationError(msg) |
| 110 | + |
| 111 | + @classmethod |
| 112 | + def _from_json_v3(cls: type[Self], data: DTypeJSON) -> Self: |
| 113 | + """ |
| 114 | + Create an instance of this ZDType from Zarr V3-flavored JSON. |
| 115 | +
|
| 116 | + This first does a type check on the input, and if that passes we create an instance of the ZDType. |
| 117 | + """ |
| 118 | + if cls._check_json_v3(data): |
| 119 | + return cls() |
| 120 | + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}" |
| 121 | + raise DataTypeValidationError(msg) |
| 122 | + |
| 123 | + @overload # type: ignore[override] |
| 124 | + def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[Literal["int2"], None]: ... |
| 125 | + |
| 126 | + @overload |
| 127 | + def to_json(self, zarr_format: Literal[3]) -> Literal["int2"]: ... |
| 128 | + |
| 129 | + def to_json( |
| 130 | + self, zarr_format: ZarrFormat |
| 131 | + ) -> DTypeConfig_V2[Literal["int2"], None] | Literal["int2"]: |
| 132 | + """ |
| 133 | + Serialize this ZDType to v2- or v3-flavored JSON |
| 134 | +
|
| 135 | + If the zarr_format is 2, then return a dict like this: |
| 136 | + .. code-block:: json |
| 137 | +
|
| 138 | + { |
| 139 | + "name": "int2", |
| 140 | + "object_codec_id": None |
| 141 | + } |
| 142 | +
|
| 143 | + If the zarr_format is 3, then return the string "int2" |
| 144 | +
|
| 145 | + """ |
| 146 | + if zarr_format == 2: |
| 147 | + return {"name": "int2", "object_codec_id": None} |
| 148 | + elif zarr_format == 3: |
| 149 | + return self._zarr_v3_name |
| 150 | + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover |
| 151 | + |
| 152 | + def _check_scalar(self, data: object) -> TypeGuard[int | ml_dtypes.int2]: |
| 153 | + """ |
| 154 | + Check if a python object is a valid int2-compatible scalar |
| 155 | +
|
| 156 | + The strictness of this type check is an implementation degree of freedom. |
| 157 | + You could be strict here, and only accept int2 values, or be open and accept any integer |
| 158 | + or any object and rely on exceptions from the int2 constructor that will be called in |
| 159 | + cast_scalar. |
| 160 | + """ |
| 161 | + return isinstance(data, (int, int2_scalar_cls)) |
| 162 | + |
| 163 | + def cast_scalar(self, data: object) -> ml_dtypes.int2: |
| 164 | + """ |
| 165 | + Attempt to cast a python object to an int2. |
| 166 | +
|
| 167 | + We first perform a type check to ensure that the input type is appropriate, and if that |
| 168 | + passes we call the int2 scalar constructor. |
| 169 | + """ |
| 170 | + if self._check_scalar(data): |
| 171 | + return ml_dtypes.int2(data) |
| 172 | + msg = ( |
| 173 | + f"Cannot convert object {data!r} with type {type(data)} to a scalar compatible with the " |
| 174 | + f"data type {self}." |
| 175 | + ) |
| 176 | + raise TypeError(msg) |
| 177 | + |
| 178 | + def default_scalar(self) -> ml_dtypes.int2: |
| 179 | + """ |
| 180 | + Get the default scalar value. This will be used when automatically selecting a fill value. |
| 181 | + """ |
| 182 | + return ml_dtypes.int2(0) |
| 183 | + |
| 184 | + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> int: |
| 185 | + """ |
| 186 | + Convert a python object to a JSON representation of an int2 scalar. |
| 187 | + This is necessary for taking user input for the ``fill_value`` attribute in array metadata. |
| 188 | +
|
| 189 | + In this implementation, we optimistically convert the input to an int, |
| 190 | + and then check that it lies in the acceptable range for this data type. |
| 191 | + """ |
| 192 | + # We could add a type check here, but we don't need to for this example |
| 193 | + val: int = int(data) # type: ignore[call-overload] |
| 194 | + if val not in (-2, -1, 0, 1): |
| 195 | + raise ValueError("Invalid value. Expected -2, -1, 0, or 1.") |
| 196 | + return val |
| 197 | + |
| 198 | + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> ml_dtypes.int2: |
| 199 | + """ |
| 200 | + Read a JSON-serializable value as an int2 scalar. |
| 201 | +
|
| 202 | + We first perform a type check to ensure that the JSON value is well-formed, then call the |
| 203 | + int2 scalar constructor. |
| 204 | +
|
| 205 | + The base definition of this method requires that it take a zarr_format parameter because |
| 206 | + other data types serialize scalars differently in zarr v2 and v3, but we don't use this here. |
| 207 | +
|
| 208 | + """ |
| 209 | + if self._check_scalar(data): |
| 210 | + return ml_dtypes.int2(data) |
| 211 | + raise TypeError(f"Invalid type: {data}. Expected an int.") |
| 212 | + |
| 213 | + |
| 214 | +# after defining dtype class, it must be registered with the data type registry so zarr can use it |
| 215 | +data_type_registry.register(Int2._zarr_v3_name, Int2) |
| 216 | + |
| 217 | + |
| 218 | +# this parametrized function will create arrays in zarr v2 and v3 using our new data type |
| 219 | +@pytest.mark.parametrize("zarr_format", [2, 3]) |
| 220 | +def test_custom_dtype(tmp_path: Path, zarr_format: Literal[2, 3]) -> None: |
| 221 | + # create array and write values |
| 222 | + z_w = zarr.create_array( |
| 223 | + store=tmp_path, shape=(4,), dtype="int2", zarr_format=zarr_format, compressors=None |
| 224 | + ) |
| 225 | + z_w[:] = [-1, -2, 0, 1] |
| 226 | + |
| 227 | + # open the array |
| 228 | + z_r = zarr.open_array(tmp_path, mode="r") |
| 229 | + |
| 230 | + print(z_r.info_complete()) |
| 231 | + |
| 232 | + # look at the array metadata |
| 233 | + if zarr_format == 2: |
| 234 | + meta_file = tmp_path / ".zarray" |
| 235 | + else: |
| 236 | + meta_file = tmp_path / "zarr.json" |
| 237 | + print(json.dumps(json.loads(meta_file.read_text()), indent=2)) |
| 238 | + |
| 239 | + |
| 240 | +if __name__ == "__main__": |
| 241 | + # Run the example with printed output, and a dummy pytest configuration file specified. |
| 242 | + # Without the dummy configuration file, at test time pytest will attempt to use the |
| 243 | + # configuration file in the project root, which will error because Zarr is using some |
| 244 | + # plugins that are not installed in this example. |
| 245 | + sys.exit(pytest.main(["-s", __file__, f"-c {__file__}"])) |
0 commit comments