Skip to content

Commit 9ee08c7

Browse files
committed
Update elementwise tests for new elementwise functions
Also add a meta-test to ensure the elementwise tests stay up-to-date.
1 parent b689d43 commit 9ee08c7

File tree

1 file changed

+76
-64
lines changed

1 file changed

+76
-64
lines changed

array_api_strict/tests/test_elementwise_functions.py

Lines changed: 76 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from inspect import getfullargspec
1+
from inspect import getfullargspec, getmodule
22

33
from numpy.testing import assert_raises
44

@@ -10,79 +10,88 @@
1010
_floating_dtypes,
1111
_integer_dtypes,
1212
)
13-
13+
from .._flags import set_array_api_strict_flags
1414

1515
def nargs(func):
1616
return len(getfullargspec(func).args)
1717

1818

19+
elementwise_function_input_types = {
20+
"abs": "numeric",
21+
"acos": "floating-point",
22+
"acosh": "floating-point",
23+
"add": "numeric",
24+
"asin": "floating-point",
25+
"asinh": "floating-point",
26+
"atan": "floating-point",
27+
"atan2": "real floating-point",
28+
"atanh": "floating-point",
29+
"bitwise_and": "integer or boolean",
30+
"bitwise_invert": "integer or boolean",
31+
"bitwise_left_shift": "integer",
32+
"bitwise_or": "integer or boolean",
33+
"bitwise_right_shift": "integer",
34+
"bitwise_xor": "integer or boolean",
35+
"ceil": "real numeric",
36+
"clip": "real numeric",
37+
"conj": "complex floating-point",
38+
"copysign": "real floating-point",
39+
"cos": "floating-point",
40+
"cosh": "floating-point",
41+
"divide": "floating-point",
42+
"equal": "all",
43+
"exp": "floating-point",
44+
"expm1": "floating-point",
45+
"floor": "real numeric",
46+
"floor_divide": "real numeric",
47+
"greater": "real numeric",
48+
"greater_equal": "real numeric",
49+
"hypot": "real floating-point",
50+
"imag": "complex floating-point",
51+
"isfinite": "numeric",
52+
"isinf": "numeric",
53+
"isnan": "numeric",
54+
"less": "real numeric",
55+
"less_equal": "real numeric",
56+
"log": "floating-point",
57+
"logaddexp": "real floating-point",
58+
"log10": "floating-point",
59+
"log1p": "floating-point",
60+
"log2": "floating-point",
61+
"logical_and": "boolean",
62+
"logical_not": "boolean",
63+
"logical_or": "boolean",
64+
"logical_xor": "boolean",
65+
"multiply": "numeric",
66+
"negative": "numeric",
67+
"not_equal": "all",
68+
"positive": "numeric",
69+
"pow": "numeric",
70+
"real": "complex floating-point",
71+
"remainder": "real numeric",
72+
"round": "numeric",
73+
"sign": "numeric",
74+
"sin": "floating-point",
75+
"sinh": "floating-point",
76+
"sqrt": "floating-point",
77+
"square": "numeric",
78+
"subtract": "numeric",
79+
"tan": "floating-point",
80+
"tanh": "floating-point",
81+
"trunc": "real numeric",
82+
}
83+
84+
def test_missing_functions():
85+
# Ensure the above dictionary is complete.
86+
import array_api_strict._elementwise_functions as mod
87+
mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod]
88+
assert set(mod_funcs) == set(elementwise_function_input_types)
89+
1990
def test_function_types():
2091
# Test that every function accepts only the required input types. We only
2192
# test the negative cases here (error). The positive cases are tested in
2293
# the array API test suite.
2394

24-
elementwise_function_input_types = {
25-
"abs": "numeric",
26-
"acos": "floating-point",
27-
"acosh": "floating-point",
28-
"add": "numeric",
29-
"asin": "floating-point",
30-
"asinh": "floating-point",
31-
"atan": "floating-point",
32-
"atan2": "real floating-point",
33-
"atanh": "floating-point",
34-
"bitwise_and": "integer or boolean",
35-
"bitwise_invert": "integer or boolean",
36-
"bitwise_left_shift": "integer",
37-
"bitwise_or": "integer or boolean",
38-
"bitwise_right_shift": "integer",
39-
"bitwise_xor": "integer or boolean",
40-
"ceil": "real numeric",
41-
"conj": "complex floating-point",
42-
"cos": "floating-point",
43-
"cosh": "floating-point",
44-
"divide": "floating-point",
45-
"equal": "all",
46-
"exp": "floating-point",
47-
"expm1": "floating-point",
48-
"floor": "real numeric",
49-
"floor_divide": "real numeric",
50-
"greater": "real numeric",
51-
"greater_equal": "real numeric",
52-
"imag": "complex floating-point",
53-
"isfinite": "numeric",
54-
"isinf": "numeric",
55-
"isnan": "numeric",
56-
"less": "real numeric",
57-
"less_equal": "real numeric",
58-
"log": "floating-point",
59-
"logaddexp": "real floating-point",
60-
"log10": "floating-point",
61-
"log1p": "floating-point",
62-
"log2": "floating-point",
63-
"logical_and": "boolean",
64-
"logical_not": "boolean",
65-
"logical_or": "boolean",
66-
"logical_xor": "boolean",
67-
"multiply": "numeric",
68-
"negative": "numeric",
69-
"not_equal": "all",
70-
"positive": "numeric",
71-
"pow": "numeric",
72-
"real": "complex floating-point",
73-
"remainder": "real numeric",
74-
"round": "numeric",
75-
"sign": "numeric",
76-
"sin": "floating-point",
77-
"sinh": "floating-point",
78-
"sqrt": "floating-point",
79-
"square": "numeric",
80-
"subtract": "numeric",
81-
"tan": "floating-point",
82-
"tanh": "floating-point",
83-
"trunc": "real numeric",
84-
}
85-
8695
def _array_vals():
8796
for d in _integer_dtypes:
8897
yield asarray(1, dtype=d)
@@ -91,6 +100,9 @@ def _array_vals():
91100
for d in _floating_dtypes:
92101
yield asarray(1.0, dtype=d)
93102

103+
# Use the latest version of the standard so all functions are included
104+
set_array_api_strict_flags(api_version="2023.12")
105+
94106
for x in _array_vals():
95107
for func_name, types in elementwise_function_input_types.items():
96108
dtypes = _dtype_categories[types]

0 commit comments

Comments
 (0)