Skip to content

Commit dc0931a

Browse files
andersy005mathausemax-sixty
authored
Raise an informative error message when object array has mixed types (#4700)
Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com> Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
1 parent e7e8c38 commit dc0931a

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

xarray/conventions.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,32 @@ def _var_as_tuple(var: Variable) -> T_VarTuple:
5252
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()
5353

5454

55-
def _infer_dtype(array, name: T_Name = None) -> np.dtype:
56-
"""Given an object array with no missing values, infer its dtype from its
57-
first element
58-
"""
55+
def _infer_dtype(array, name=None):
56+
"""Given an object array with no missing values, infer its dtype from all elements."""
5957
if array.dtype.kind != "O":
6058
raise TypeError("infer_type must be called on a dtype=object array")
6159

6260
if array.size == 0:
6361
return np.dtype(float)
6462

63+
native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
64+
if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
65+
raise ValueError(
66+
"unable to infer dtype on variable {!r}; object array "
67+
"contains mixed native types: {}".format(
68+
name, ", ".join(x.__name__ for x in native_dtypes)
69+
)
70+
)
71+
72+
native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
73+
if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
74+
raise ValueError(
75+
"unable to infer dtype on variable {!r}; object array "
76+
"contains mixed native types: {}".format(
77+
name, ", ".join(x.__name__ for x in native_dtypes)
78+
)
79+
)
80+
6581
element = array[(0,) * array.ndim]
6682
# We use the base types to avoid subclasses of bytes and str (which might
6783
# not play nice with e.g. hdf5 datatypes), such as those from numpy

xarray/tests/test_conventions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,18 @@ def test_encoding_kwarg_fixed_width_string(self) -> None:
495495
pass
496496

497497

498+
@pytest.mark.parametrize(
499+
"data",
500+
[
501+
np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
502+
np.array([["x", 1], ["y", 2]], dtype="object"),
503+
],
504+
)
505+
def test_infer_dtype_error_on_mixed_types(data):
506+
with pytest.raises(ValueError, match="unable to infer dtype on variable"):
507+
conventions._infer_dtype(data, "test")
508+
509+
498510
class TestDecodeCFVariableWithArrayUnits:
499511
def test_decode_cf_variable_with_array_units(self) -> None:
500512
v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})

0 commit comments

Comments
 (0)