|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | from ._array_object import Array
|
4 |
| -from ._dtypes import _all_dtypes, _result_type |
| 4 | +from ._dtypes import ( |
| 5 | + _all_dtypes, |
| 6 | + _boolean_dtypes, |
| 7 | + _signed_integer_dtypes, |
| 8 | + _unsigned_integer_dtypes, |
| 9 | + _integer_dtypes, |
| 10 | + _real_floating_dtypes, |
| 11 | + _complex_floating_dtypes, |
| 12 | + _numeric_dtypes, |
| 13 | + _result_type, |
| 14 | +) |
5 | 15 |
|
6 | 16 | from dataclasses import dataclass
|
7 | 17 | from typing import TYPE_CHECKING, List, Tuple, Union
|
@@ -117,6 +127,44 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
|
117 | 127 | return iinfo_object(ii.bits, ii.max, ii.min)
|
118 | 128 |
|
119 | 129 |
|
| 130 | +# Note: isdtype is a new function from the 2022.12 array API specification. |
| 131 | +def isdtype( |
| 132 | + dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]] |
| 133 | +) -> bool: |
| 134 | + """ |
| 135 | + Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. |
| 136 | +
|
| 137 | + See |
| 138 | + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html |
| 139 | + for more details |
| 140 | + """ |
| 141 | + if isinstance(kind, tuple): |
| 142 | + # Disallow nested tuples |
| 143 | + if any(isinstance(k, tuple) for k in kind): |
| 144 | + raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs") |
| 145 | + return any(isdtype(dtype, k) for k in kind) |
| 146 | + elif isinstance(kind, str): |
| 147 | + if kind == 'bool': |
| 148 | + return dtype in _boolean_dtypes |
| 149 | + elif kind == 'signed integer': |
| 150 | + return dtype in _signed_integer_dtypes |
| 151 | + elif kind == 'unsigned integer': |
| 152 | + return dtype in _unsigned_integer_dtypes |
| 153 | + elif kind == 'integral': |
| 154 | + return dtype in _integer_dtypes |
| 155 | + elif kind == 'real floating': |
| 156 | + return dtype in _real_floating_dtypes |
| 157 | + elif kind == 'complex floating': |
| 158 | + return dtype in _complex_floating_dtypes |
| 159 | + elif kind == 'numeric': |
| 160 | + return dtype in _numeric_dtypes |
| 161 | + else: |
| 162 | + raise ValueError(f"Unrecognized data type kind: {kind!r}") |
| 163 | + elif kind in _all_dtypes: |
| 164 | + return dtype == kind |
| 165 | + else: |
| 166 | + raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}") |
| 167 | + |
120 | 168 | def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
|
121 | 169 | """
|
122 | 170 | Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
|
|
0 commit comments