1
- from inspect import getfullargspec
1
+ from inspect import getfullargspec , getmodule
2
2
3
3
from numpy .testing import assert_raises
4
4
10
10
_floating_dtypes ,
11
11
_integer_dtypes ,
12
12
)
13
-
13
+ from .. _flags import set_array_api_strict_flags
14
14
15
15
def nargs (func ):
16
16
return len (getfullargspec (func ).args )
17
17
18
18
19
+ elementwise_function_input_types = {
20
+ "abs" : "numeric" ,
21
+ "acos" : "floating-point" ,
22
+ "acosh" : "floating-point" ,
23
+ "add" : "numeric" ,
24
+ "asin" : "floating-point" ,
25
+ "asinh" : "floating-point" ,
26
+ "atan" : "floating-point" ,
27
+ "atan2" : "real floating-point" ,
28
+ "atanh" : "floating-point" ,
29
+ "bitwise_and" : "integer or boolean" ,
30
+ "bitwise_invert" : "integer or boolean" ,
31
+ "bitwise_left_shift" : "integer" ,
32
+ "bitwise_or" : "integer or boolean" ,
33
+ "bitwise_right_shift" : "integer" ,
34
+ "bitwise_xor" : "integer or boolean" ,
35
+ "ceil" : "real numeric" ,
36
+ "clip" : "real numeric" ,
37
+ "conj" : "complex floating-point" ,
38
+ "copysign" : "real floating-point" ,
39
+ "cos" : "floating-point" ,
40
+ "cosh" : "floating-point" ,
41
+ "divide" : "floating-point" ,
42
+ "equal" : "all" ,
43
+ "exp" : "floating-point" ,
44
+ "expm1" : "floating-point" ,
45
+ "floor" : "real numeric" ,
46
+ "floor_divide" : "real numeric" ,
47
+ "greater" : "real numeric" ,
48
+ "greater_equal" : "real numeric" ,
49
+ "hypot" : "real floating-point" ,
50
+ "imag" : "complex floating-point" ,
51
+ "isfinite" : "numeric" ,
52
+ "isinf" : "numeric" ,
53
+ "isnan" : "numeric" ,
54
+ "less" : "real numeric" ,
55
+ "less_equal" : "real numeric" ,
56
+ "log" : "floating-point" ,
57
+ "logaddexp" : "real floating-point" ,
58
+ "log10" : "floating-point" ,
59
+ "log1p" : "floating-point" ,
60
+ "log2" : "floating-point" ,
61
+ "logical_and" : "boolean" ,
62
+ "logical_not" : "boolean" ,
63
+ "logical_or" : "boolean" ,
64
+ "logical_xor" : "boolean" ,
65
+ "multiply" : "numeric" ,
66
+ "negative" : "numeric" ,
67
+ "not_equal" : "all" ,
68
+ "positive" : "numeric" ,
69
+ "pow" : "numeric" ,
70
+ "real" : "complex floating-point" ,
71
+ "remainder" : "real numeric" ,
72
+ "round" : "numeric" ,
73
+ "sign" : "numeric" ,
74
+ "sin" : "floating-point" ,
75
+ "sinh" : "floating-point" ,
76
+ "sqrt" : "floating-point" ,
77
+ "square" : "numeric" ,
78
+ "subtract" : "numeric" ,
79
+ "tan" : "floating-point" ,
80
+ "tanh" : "floating-point" ,
81
+ "trunc" : "real numeric" ,
82
+ }
83
+
84
+ def test_missing_functions ():
85
+ # Ensure the above dictionary is complete.
86
+ import array_api_strict ._elementwise_functions as mod
87
+ mod_funcs = [n for n in dir (mod ) if getmodule (getattr (mod , n )) is mod ]
88
+ assert set (mod_funcs ) == set (elementwise_function_input_types )
89
+
19
90
def test_function_types ():
20
91
# Test that every function accepts only the required input types. We only
21
92
# test the negative cases here (error). The positive cases are tested in
22
93
# the array API test suite.
23
94
24
- elementwise_function_input_types = {
25
- "abs" : "numeric" ,
26
- "acos" : "floating-point" ,
27
- "acosh" : "floating-point" ,
28
- "add" : "numeric" ,
29
- "asin" : "floating-point" ,
30
- "asinh" : "floating-point" ,
31
- "atan" : "floating-point" ,
32
- "atan2" : "real floating-point" ,
33
- "atanh" : "floating-point" ,
34
- "bitwise_and" : "integer or boolean" ,
35
- "bitwise_invert" : "integer or boolean" ,
36
- "bitwise_left_shift" : "integer" ,
37
- "bitwise_or" : "integer or boolean" ,
38
- "bitwise_right_shift" : "integer" ,
39
- "bitwise_xor" : "integer or boolean" ,
40
- "ceil" : "real numeric" ,
41
- "conj" : "complex floating-point" ,
42
- "cos" : "floating-point" ,
43
- "cosh" : "floating-point" ,
44
- "divide" : "floating-point" ,
45
- "equal" : "all" ,
46
- "exp" : "floating-point" ,
47
- "expm1" : "floating-point" ,
48
- "floor" : "real numeric" ,
49
- "floor_divide" : "real numeric" ,
50
- "greater" : "real numeric" ,
51
- "greater_equal" : "real numeric" ,
52
- "imag" : "complex floating-point" ,
53
- "isfinite" : "numeric" ,
54
- "isinf" : "numeric" ,
55
- "isnan" : "numeric" ,
56
- "less" : "real numeric" ,
57
- "less_equal" : "real numeric" ,
58
- "log" : "floating-point" ,
59
- "logaddexp" : "real floating-point" ,
60
- "log10" : "floating-point" ,
61
- "log1p" : "floating-point" ,
62
- "log2" : "floating-point" ,
63
- "logical_and" : "boolean" ,
64
- "logical_not" : "boolean" ,
65
- "logical_or" : "boolean" ,
66
- "logical_xor" : "boolean" ,
67
- "multiply" : "numeric" ,
68
- "negative" : "numeric" ,
69
- "not_equal" : "all" ,
70
- "positive" : "numeric" ,
71
- "pow" : "numeric" ,
72
- "real" : "complex floating-point" ,
73
- "remainder" : "real numeric" ,
74
- "round" : "numeric" ,
75
- "sign" : "numeric" ,
76
- "sin" : "floating-point" ,
77
- "sinh" : "floating-point" ,
78
- "sqrt" : "floating-point" ,
79
- "square" : "numeric" ,
80
- "subtract" : "numeric" ,
81
- "tan" : "floating-point" ,
82
- "tanh" : "floating-point" ,
83
- "trunc" : "real numeric" ,
84
- }
85
-
86
95
def _array_vals ():
87
96
for d in _integer_dtypes :
88
97
yield asarray (1 , dtype = d )
@@ -91,6 +100,9 @@ def _array_vals():
91
100
for d in _floating_dtypes :
92
101
yield asarray (1.0 , dtype = d )
93
102
103
+ # Use the latest version of the standard so all functions are included
104
+ set_array_api_strict_flags (api_version = "2023.12" )
105
+
94
106
for x in _array_vals ():
95
107
for func_name , types in elementwise_function_input_types .items ():
96
108
dtypes = _dtype_categories [types ]
0 commit comments