Skip to content

Commit 14a544c

Browse files
authored
move ensure_dtype_not_object from conventions to backends (#9828)
* move ensure_dtype_not_object from conventions to backends * add whats-new.rst entry
1 parent 7fd572d commit 14a544c

File tree

7 files changed

+142
-125
lines changed

7 files changed

+142
-125
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ Documentation
4747

4848
Internal Changes
4949
~~~~~~~~~~~~~~~~
50-
51-
50+
- Move non-CF related ``ensure_dtype_not_object`` from conventions to backends (:pull:`9828`).
51+
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
5252

5353
.. _whats-new.2024.11.0:
5454

xarray/backends/common.py

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,36 @@
44
import os
55
import time
66
import traceback
7-
from collections.abc import Iterable, Mapping, Sequence
7+
from collections.abc import Hashable, Iterable, Mapping, Sequence
88
from glob import glob
9-
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload
9+
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union, overload
1010

1111
import numpy as np
12+
import pandas as pd
1213

14+
from xarray.coding import strings, variables
15+
from xarray.coding.variables import SerializationWarning
1316
from xarray.conventions import cf_encoder
1417
from xarray.core import indexing
15-
from xarray.core.datatree import DataTree
18+
from xarray.core.datatree import DataTree, Variable
1619
from xarray.core.types import ReadBuffer
1720
from xarray.core.utils import (
1821
FrozenDict,
1922
NdimSizeLenMixin,
2023
attempt_import,
24+
emit_user_level_warning,
2125
is_remote_uri,
2226
)
2327
from xarray.namedarray.parallelcompat import get_chunked_array_type
2428
from xarray.namedarray.pycompat import is_chunked_array
29+
from xarray.namedarray.utils import is_duck_dask_array
2530

2631
if TYPE_CHECKING:
2732
from xarray.core.dataset import Dataset
2833
from xarray.core.types import NestedSequence
2934

35+
T_Name = Union[Hashable, None]
36+
3037
# Create a logger object, but don't add any handlers. Leave that to user code.
3138
logger = logging.getLogger(__name__)
3239

@@ -527,13 +534,111 @@ def set_dimensions(self, variables, unlimited_dims=None):
527534
self.set_dimension(dim, length, is_unlimited)
528535

529536

537+
def _infer_dtype(array, name=None):
538+
"""Given an object array with no missing values, infer its dtype from all elements."""
539+
if array.dtype.kind != "O":
540+
raise TypeError("infer_type must be called on a dtype=object array")
541+
542+
if array.size == 0:
543+
return np.dtype(float)
544+
545+
native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
546+
if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
547+
raise ValueError(
548+
"unable to infer dtype on variable {!r}; object array "
549+
"contains mixed native types: {}".format(
550+
name, ", ".join(x.__name__ for x in native_dtypes)
551+
)
552+
)
553+
554+
element = array[(0,) * array.ndim]
555+
# We use the base types to avoid subclasses of bytes and str (which might
556+
# not play nice with e.g. hdf5 datatypes), such as those from numpy
557+
if isinstance(element, bytes):
558+
return strings.create_vlen_dtype(bytes)
559+
elif isinstance(element, str):
560+
return strings.create_vlen_dtype(str)
561+
562+
dtype = np.array(element).dtype
563+
if dtype.kind != "O":
564+
return dtype
565+
566+
raise ValueError(
567+
f"unable to infer dtype on variable {name!r}; xarray "
568+
"cannot serialize arbitrary Python objects"
569+
)
570+
571+
572+
def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
573+
"""Create a copy of an array with the given dtype.
574+
575+
We use this instead of np.array() to ensure that custom object dtypes end
576+
up on the resulting array.
577+
"""
578+
result = np.empty(data.shape, dtype)
579+
result[...] = data
580+
return result
581+
582+
583+
def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
584+
if var.dtype.kind == "O":
585+
dims, data, attrs, encoding = variables.unpack_for_encoding(var)
586+
587+
# leave vlen dtypes unchanged
588+
if strings.check_vlen_dtype(data.dtype) is not None:
589+
return var
590+
591+
if is_duck_dask_array(data):
592+
emit_user_level_warning(
593+
f"variable {name} has data in the form of a dask array with "
594+
"dtype=object, which means it is being loaded into memory "
595+
"to determine a data type that can be safely stored on disk. "
596+
"To avoid this, coerce this variable to a fixed-size dtype "
597+
"with astype() before saving it.",
598+
category=SerializationWarning,
599+
)
600+
data = data.compute()
601+
602+
missing = pd.isnull(data)
603+
if missing.any():
604+
# nb. this will fail for dask.array data
605+
non_missing_values = data[~missing]
606+
inferred_dtype = _infer_dtype(non_missing_values, name)
607+
608+
# There is no safe bit-pattern for NA in typical binary string
609+
# formats, so we can't set a fill_value. Unfortunately, this means
610+
# we can't distinguish between missing values and empty strings.
611+
fill_value: bytes | str
612+
if strings.is_bytes_dtype(inferred_dtype):
613+
fill_value = b""
614+
elif strings.is_unicode_dtype(inferred_dtype):
615+
fill_value = ""
616+
else:
617+
# insist on using float for numeric values
618+
if not np.issubdtype(inferred_dtype, np.floating):
619+
inferred_dtype = np.dtype(float)
620+
fill_value = inferred_dtype.type(np.nan)
621+
622+
data = _copy_with_dtype(data, dtype=inferred_dtype)
623+
data[missing] = fill_value
624+
else:
625+
data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))
626+
627+
assert data.dtype.kind != "O" or data.dtype.metadata
628+
var = Variable(dims, data, attrs, encoding, fastpath=True)
629+
return var
630+
631+
530632
class WritableCFDataStore(AbstractWritableDataStore):
531633
__slots__ = ()
532634

533635
def encode(self, variables, attributes):
534636
# All NetCDF files get CF encoded by default, without this attempting
535637
# to write times, for example, would fail.
536638
variables, attributes = cf_encoder(variables, attributes)
639+
variables = {
640+
k: ensure_dtype_not_object(v, name=k) for k, v in variables.items()
641+
}
537642
variables = {k: self.encode_variable(v) for k, v in variables.items()}
538643
attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}
539644
return variables, attributes

xarray/backends/zarr.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
_encode_variable_name,
2020
_normalize_path,
2121
datatree_from_dict_with_io_cleanup,
22+
ensure_dtype_not_object,
2223
)
2324
from xarray.backends.store import StoreBackendEntrypoint
2425
from xarray.core import indexing
@@ -507,6 +508,7 @@ def encode_zarr_variable(var, needs_copy=True, name=None):
507508
"""
508509

509510
var = conventions.encode_cf_variable(var, name=name)
511+
var = ensure_dtype_not_object(var, name=name)
510512

511513
# zarr allows unicode, but not variable-length strings, so it's both
512514
# simpler and more compact to always encode as UTF-8 explicitly.

xarray/conventions.py

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union
77

88
import numpy as np
9-
import pandas as pd
109

1110
from xarray.coding import strings, times, variables
1211
from xarray.coding.variables import SerializationWarning, pop_to
@@ -50,41 +49,6 @@
5049
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]
5150

5251

53-
def _infer_dtype(array, name=None):
54-
"""Given an object array with no missing values, infer its dtype from all elements."""
55-
if array.dtype.kind != "O":
56-
raise TypeError("infer_type must be called on a dtype=object array")
57-
58-
if array.size == 0:
59-
return np.dtype(float)
60-
61-
native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
62-
if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
63-
raise ValueError(
64-
"unable to infer dtype on variable {!r}; object array "
65-
"contains mixed native types: {}".format(
66-
name, ", ".join(x.__name__ for x in native_dtypes)
67-
)
68-
)
69-
70-
element = array[(0,) * array.ndim]
71-
# We use the base types to avoid subclasses of bytes and str (which might
72-
# not play nice with e.g. hdf5 datatypes), such as those from numpy
73-
if isinstance(element, bytes):
74-
return strings.create_vlen_dtype(bytes)
75-
elif isinstance(element, str):
76-
return strings.create_vlen_dtype(str)
77-
78-
dtype = np.array(element).dtype
79-
if dtype.kind != "O":
80-
return dtype
81-
82-
raise ValueError(
83-
f"unable to infer dtype on variable {name!r}; xarray "
84-
"cannot serialize arbitrary Python objects"
85-
)
86-
87-
8852
def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
8953
# only the pandas multi-index dimension coordinate cannot be serialized (tuple values)
9054
if isinstance(var._data, indexing.PandasMultiIndexingAdapter):
@@ -99,67 +63,6 @@ def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
9963
)
10064

10165

102-
def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
103-
"""Create a copy of an array with the given dtype.
104-
105-
We use this instead of np.array() to ensure that custom object dtypes end
106-
up on the resulting array.
107-
"""
108-
result = np.empty(data.shape, dtype)
109-
result[...] = data
110-
return result
111-
112-
113-
def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
114-
# TODO: move this from conventions to backends? (it's not CF related)
115-
if var.dtype.kind == "O":
116-
dims, data, attrs, encoding = variables.unpack_for_encoding(var)
117-
118-
# leave vlen dtypes unchanged
119-
if strings.check_vlen_dtype(data.dtype) is not None:
120-
return var
121-
122-
if is_duck_dask_array(data):
123-
emit_user_level_warning(
124-
f"variable {name} has data in the form of a dask array with "
125-
"dtype=object, which means it is being loaded into memory "
126-
"to determine a data type that can be safely stored on disk. "
127-
"To avoid this, coerce this variable to a fixed-size dtype "
128-
"with astype() before saving it.",
129-
category=SerializationWarning,
130-
)
131-
data = data.compute()
132-
133-
missing = pd.isnull(data)
134-
if missing.any():
135-
# nb. this will fail for dask.array data
136-
non_missing_values = data[~missing]
137-
inferred_dtype = _infer_dtype(non_missing_values, name)
138-
139-
# There is no safe bit-pattern for NA in typical binary string
140-
# formats, so we can't set a fill_value. Unfortunately, this means
141-
# we can't distinguish between missing values and empty strings.
142-
fill_value: bytes | str
143-
if strings.is_bytes_dtype(inferred_dtype):
144-
fill_value = b""
145-
elif strings.is_unicode_dtype(inferred_dtype):
146-
fill_value = ""
147-
else:
148-
# insist on using float for numeric values
149-
if not np.issubdtype(inferred_dtype, np.floating):
150-
inferred_dtype = np.dtype(float)
151-
fill_value = inferred_dtype.type(np.nan)
152-
153-
data = _copy_with_dtype(data, dtype=inferred_dtype)
154-
data[missing] = fill_value
155-
else:
156-
data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))
157-
158-
assert data.dtype.kind != "O" or data.dtype.metadata
159-
var = Variable(dims, data, attrs, encoding, fastpath=True)
160-
return var
161-
162-
16366
def encode_cf_variable(
16467
var: Variable, needs_copy: bool = True, name: T_Name = None
16568
) -> Variable:
@@ -196,9 +99,6 @@ def encode_cf_variable(
19699
]:
197100
var = coder.encode(var, name=name)
198101

199-
# TODO(kmuehlbauer): check if ensure_dtype_not_object can be moved to backends:
200-
var = ensure_dtype_not_object(var, name=name)
201-
202102
for attr_name in CF_RELATED_DATA:
203103
pop_to(var.encoding, var.attrs, attr_name)
204104
return var

xarray/tests/test_backends.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,6 +1400,22 @@ def test_multiindex_not_implemented(self) -> None:
14001400
with self.roundtrip(ds_reset) as actual:
14011401
assert_identical(actual, ds_reset)
14021402

1403+
@requires_dask
1404+
def test_string_object_warning(self) -> None:
1405+
original = Dataset(
1406+
{
1407+
"x": (
1408+
[
1409+
"y",
1410+
],
1411+
np.array(["foo", "bar"], dtype=object),
1412+
)
1413+
}
1414+
).chunk()
1415+
with pytest.warns(SerializationWarning, match="dask array with dtype=object"):
1416+
with self.roundtrip(original) as actual:
1417+
assert_identical(original, actual)
1418+
14031419

14041420
class NetCDFBase(CFEncodedBase):
14051421
"""Tests for all netCDF3 and netCDF4 backends."""

xarray/tests/test_backends_common.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3+
import numpy as np
34
import pytest
45

5-
from xarray.backends.common import robust_getitem
6+
from xarray.backends.common import _infer_dtype, robust_getitem
67

78

89
class DummyFailure(Exception):
@@ -30,3 +31,15 @@ def test_robust_getitem() -> None:
3031
array = DummyArray(failures=3)
3132
with pytest.raises(DummyFailure):
3233
robust_getitem(array, ..., catch=DummyFailure, initial_delay=1, max_retries=2)
34+
35+
36+
@pytest.mark.parametrize(
37+
"data",
38+
[
39+
np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
40+
np.array([["x", 1], ["y", 2]], dtype="object"),
41+
],
42+
)
43+
def test_infer_dtype_error_on_mixed_types(data):
44+
with pytest.raises(ValueError, match="unable to infer dtype on variable"):
45+
_infer_dtype(data, "test")

xarray/tests/test_conventions.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -249,13 +249,6 @@ def test_emit_coordinates_attribute_in_encoding(self) -> None:
249249
assert enc["b"].attrs.get("coordinates") == "t"
250250
assert "coordinates" not in enc["b"].encoding
251251

252-
@requires_dask
253-
def test_string_object_warning(self) -> None:
254-
original = Variable(("x",), np.array(["foo", "bar"], dtype=object)).chunk()
255-
with pytest.warns(SerializationWarning, match="dask array with dtype=object"):
256-
encoded = conventions.encode_cf_variable(original)
257-
assert_identical(original, encoded)
258-
259252

260253
@requires_cftime
261254
class TestDecodeCF:
@@ -593,18 +586,6 @@ def test_encoding_kwarg_fixed_width_string(self) -> None:
593586
pass
594587

595588

596-
@pytest.mark.parametrize(
597-
"data",
598-
[
599-
np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
600-
np.array([["x", 1], ["y", 2]], dtype="object"),
601-
],
602-
)
603-
def test_infer_dtype_error_on_mixed_types(data):
604-
with pytest.raises(ValueError, match="unable to infer dtype on variable"):
605-
conventions._infer_dtype(data, "test")
606-
607-
608589
class TestDecodeCFVariableWithArrayUnits:
609590
def test_decode_cf_variable_with_array_units(self) -> None:
610591
v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})

0 commit comments

Comments
 (0)