Skip to content

Commit 8f3bfdc

Browse files
committed
Create dh.func_in_dtypes from parsing the spec
1 parent c3798d2 commit 8f3bfdc

File tree

1 file changed

+25
-61
lines changed

1 file changed

+25
-61
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 25 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import re
12
from collections.abc import Mapping
23
from functools import lru_cache
3-
from typing import Any, NamedTuple, Sequence, Tuple, Union
4+
from inspect import signature
5+
from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union
46
from warnings import warn
57

68
from . import _array_module as xp
79
from ._array_module import _UndefinedStub
10+
from .stubs import name_to_func
811
from .typing import DataType, ScalarType
912

1013
__all__ = [
@@ -242,67 +245,28 @@ def result_type(*dtypes: DataType):
242245
return result
243246

244247

245-
func_in_dtypes = {
246-
# elementwise
247-
"abs": numeric_dtypes,
248-
"acos": float_dtypes,
249-
"acosh": float_dtypes,
250-
"add": numeric_dtypes,
251-
"asin": float_dtypes,
252-
"asinh": float_dtypes,
253-
"atan": float_dtypes,
254-
"atan2": float_dtypes,
255-
"atanh": float_dtypes,
256-
"bitwise_and": bool_and_all_int_dtypes,
257-
"bitwise_invert": bool_and_all_int_dtypes,
258-
"bitwise_left_shift": all_int_dtypes,
259-
"bitwise_or": bool_and_all_int_dtypes,
260-
"bitwise_right_shift": all_int_dtypes,
261-
"bitwise_xor": bool_and_all_int_dtypes,
262-
"ceil": numeric_dtypes,
263-
"cos": float_dtypes,
264-
"cosh": float_dtypes,
265-
"divide": float_dtypes,
266-
"equal": all_dtypes,
267-
"exp": float_dtypes,
268-
"expm1": float_dtypes,
269-
"floor": numeric_dtypes,
270-
"floor_divide": numeric_dtypes,
271-
"greater": numeric_dtypes,
272-
"greater_equal": numeric_dtypes,
273-
"isfinite": numeric_dtypes,
274-
"isinf": numeric_dtypes,
275-
"isnan": numeric_dtypes,
276-
"less": numeric_dtypes,
277-
"less_equal": numeric_dtypes,
278-
"log": float_dtypes,
279-
"logaddexp": float_dtypes,
280-
"log10": float_dtypes,
281-
"log1p": float_dtypes,
282-
"log2": float_dtypes,
283-
"logical_and": (xp.bool,),
284-
"logical_not": (xp.bool,),
285-
"logical_or": (xp.bool,),
286-
"logical_xor": (xp.bool,),
287-
"multiply": numeric_dtypes,
288-
"negative": numeric_dtypes,
289-
"not_equal": all_dtypes,
290-
"positive": numeric_dtypes,
291-
"pow": numeric_dtypes,
292-
"remainder": numeric_dtypes,
293-
"round": numeric_dtypes,
294-
"sign": numeric_dtypes,
295-
"sin": float_dtypes,
296-
"sinh": float_dtypes,
297-
"sqrt": float_dtypes,
298-
"square": numeric_dtypes,
299-
"subtract": numeric_dtypes,
300-
"tan": float_dtypes,
301-
"tanh": float_dtypes,
302-
"trunc": numeric_dtypes,
303-
# searching
304-
"where": all_dtypes,
248+
r_in_dtypes = re.compile("x1?: array\n.+Should have an? (.+) data type.")
249+
r_int_note = re.compile(
250+
"If one or both of the input arrays have integer data types, "
251+
"the result is implementation-dependent"
252+
)
253+
category_to_dtypes = {
254+
"boolean": (xp.bool,),
255+
"integer": all_int_dtypes,
256+
"floating-point": float_dtypes,
257+
"numeric": numeric_dtypes,
258+
"integer or boolean": bool_and_all_int_dtypes,
305259
}
260+
func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {}
261+
for name, func in name_to_func.items():
262+
if m := r_in_dtypes.search(func.__doc__):
263+
dtype_category = m.group(1)
264+
if dtype_category == "numeric" and r_int_note.search(func.__doc__):
265+
dtype_category = "floating-point"
266+
dtypes = category_to_dtypes[dtype_category]
267+
func_in_dtypes[name] = dtypes
268+
elif any("x" in name for name in signature(func).parameters.keys()):
269+
func_in_dtypes[name] = all_dtypes
306270

307271

308272
func_returns_bool = {

0 commit comments

Comments
 (0)