Skip to content

Commit 520bc70

Browse files
committed
Add isdtype() to numpy.array_api
This is a new function in the v2022.12 version of the array API standard which is used for determining if a given dtype is part of a set of given dtype categories. This will also eventually be added to the main NumPy namespace, but for now only exists in numpy.array_api as a purely strict version. Original NumPy Commit: 173fbc7009719ce802aa70634fb93031a0c00cfb
1 parent 94cc065 commit 520bc70

File tree

4 files changed

+65
-2
lines changed

4 files changed

+65
-2
lines changed

array_api_strict/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@
173173
broadcast_to,
174174
can_cast,
175175
finfo,
176+
isdtype,
176177
iinfo,
177178
result_type,
178179
)

array_api_strict/_data_type_functions.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
from __future__ import annotations
22

33
from ._array_object import Array
4-
from ._dtypes import _all_dtypes, _result_type
4+
from ._dtypes import (
5+
_all_dtypes,
6+
_boolean_dtypes,
7+
_signed_integer_dtypes,
8+
_unsigned_integer_dtypes,
9+
_integer_dtypes,
10+
_real_floating_dtypes,
11+
_complex_floating_dtypes,
12+
_numeric_dtypes,
13+
_result_type,
14+
)
515

616
from dataclasses import dataclass
717
from typing import TYPE_CHECKING, List, Tuple, Union
@@ -117,6 +127,44 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
117127
return iinfo_object(ii.bits, ii.max, ii.min)
118128

119129

130+
# Note: isdtype is a new function from the 2022.12 array API specification.
131+
def isdtype(
132+
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]]
133+
) -> bool:
134+
"""
135+
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
136+
137+
See
138+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
139+
for more details
140+
"""
141+
if isinstance(kind, tuple):
142+
# Disallow nested tuples
143+
if any(isinstance(k, tuple) for k in kind):
144+
raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs")
145+
return any(isdtype(dtype, k) for k in kind)
146+
elif isinstance(kind, str):
147+
if kind == 'bool':
148+
return dtype in _boolean_dtypes
149+
elif kind == 'signed integer':
150+
return dtype in _signed_integer_dtypes
151+
elif kind == 'unsigned integer':
152+
return dtype in _unsigned_integer_dtypes
153+
elif kind == 'integral':
154+
return dtype in _integer_dtypes
155+
elif kind == 'real floating':
156+
return dtype in _real_floating_dtypes
157+
elif kind == 'complex floating':
158+
return dtype in _complex_floating_dtypes
159+
elif kind == 'numeric':
160+
return dtype in _numeric_dtypes
161+
else:
162+
raise ValueError(f"Unrecognized data type kind: {kind!r}")
163+
elif kind in _all_dtypes:
164+
return dtype == kind
165+
else:
166+
raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")
167+
120168
def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
121169
"""
122170
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.

array_api_strict/_dtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
_floating_dtypes = (float32, float64, complex64, complex128)
3838
_complex_floating_dtypes = (complex64, complex128)
3939
_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64)
40+
_signed_integer_dtypes = (int8, int16, int32, int64)
41+
_unsigned_integer_dtypes = (uint8, uint16, uint32, uint64)
4042
_integer_or_boolean_dtypes = (
4143
bool,
4244
int8,

array_api_strict/tests/test_data_type_functions.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
22

3+
from numpy.testing import assert_raises
34
from numpy import array_api as xp
4-
5+
import numpy as np
56

67
@pytest.mark.parametrize(
78
"from_, to, expected",
@@ -17,3 +18,14 @@ def test_can_cast(from_, to, expected):
1718
can_cast() returns correct result
1819
"""
1920
assert xp.can_cast(from_, to) == expected
21+
22+
def test_isdtype_strictness():
23+
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, 64))
24+
assert_raises(ValueError, lambda: xp.isdtype(xp.float64, 'f8'))
25+
26+
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, (('integral',),)))
27+
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, np.object_))
28+
29+
# TODO: These will require https://github.com/numpy/numpy/issues/23883
30+
# assert_raises(TypeError, lambda: xp.isdtype(xp.float64, None))
31+
# assert_raises(TypeError, lambda: xp.isdtype(xp.float64, np.float64))

0 commit comments

Comments
 (0)