Skip to content

Commit fbe73ef

Browse files
keewisTomNicholas
andauthored
drop the length from numpy's fixed-width string dtypes (#9586)
* check that the length of fixed-width numpy strings is reset * drop the length from numpy's fixed-width string dtypes * compatibility with `numpy<2` * use `issubdtype` instead * some more test cases * more details in the comment --------- Co-authored-by: Tom Nicholas <tom@cworthy.org>
1 parent f24cae3 commit fbe73ef

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

xarray/core/dtypes.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,17 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
221221
return xp.isdtype(dtype, kind)
222222

223223

224-
def preprocess_scalar_types(t):
224+
def preprocess_types(t):
225225
if isinstance(t, str | bytes):
226226
return type(t)
227+
elif isinstance(dtype := getattr(t, "dtype", t), np.dtype) and (
228+
np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)
229+
):
230+
# drop the length from numpy's fixed-width string dtypes, it is better to
231+
# recalculate
232+
# TODO(keewis): remove once the minimum version of `numpy.result_type` does this
233+
# for us
234+
return dtype.type
227235
else:
228236
return t
229237

@@ -255,7 +263,7 @@ def result_type(
255263
xp = get_array_namespace(arrays_and_dtypes)
256264

257265
types = {
258-
array_api_compat.result_type(preprocess_scalar_types(t), xp=xp)
266+
array_api_compat.result_type(preprocess_types(t), xp=xp)
259267
for t in arrays_and_dtypes
260268
}
261269
if any(isinstance(t, np.dtype) for t in types):
@@ -268,5 +276,5 @@ def result_type(
268276
return np.dtype(object)
269277

270278
return array_api_compat.result_type(
271-
*map(preprocess_scalar_types, arrays_and_dtypes), xp=xp
279+
*map(preprocess_types, arrays_and_dtypes), xp=xp
272280
)

xarray/tests/test_dtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ class DummyArrayAPINamespace:
2828
([np.str_, np.int64], np.object_),
2929
([np.str_, np.str_], np.str_),
3030
([np.bytes_, np.str_], np.object_),
31+
([np.dtype("<U2"), np.str_], np.dtype("U")),
32+
([np.dtype("<U2"), str], np.dtype("U")),
33+
([np.dtype("S3"), np.bytes_], np.dtype("S")),
34+
([np.dtype("S10"), bytes], np.dtype("S")),
3135
],
3236
)
3337
def test_result_type(args, expected) -> None:

0 commit comments

Comments
 (0)