Skip to content

Commit c1f27b5

Browse files
committed
Update dtype strictness for complex numbers in array_api elementwise functions
Original NumPy Commit: c866ef19c71b2d0269340ce984be42fd8de45e28
1 parent 520bc70 commit c1f27b5

File tree

2 files changed

+35
-33
lines changed

2 files changed

+35
-33
lines changed

array_api_strict/_elementwise_functions.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from ._dtypes import (
44
_boolean_dtypes,
55
_floating_dtypes,
6+
_real_floating_dtypes,
67
_complex_floating_dtypes,
78
_integer_dtypes,
89
_integer_or_boolean_dtypes,
10+
_real_numeric_dtypes,
911
_numeric_dtypes,
1012
_result_type,
1113
)
@@ -106,8 +108,8 @@ def atan2(x1: Array, x2: Array, /) -> Array:
106108
107109
See its docstring for more information.
108110
"""
109-
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
110-
raise TypeError("Only floating-point dtypes are allowed in atan2")
111+
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
112+
raise TypeError("Only real floating-point dtypes are allowed in atan2")
111113
# Call result type here just to raise on disallowed type combinations
112114
_result_type(x1.dtype, x2.dtype)
113115
x1, x2 = Array._normalize_two_args(x1, x2)
@@ -231,8 +233,8 @@ def ceil(x: Array, /) -> Array:
231233
232234
See its docstring for more information.
233235
"""
234-
if x.dtype not in _numeric_dtypes:
235-
raise TypeError("Only numeric dtypes are allowed in ceil")
236+
if x.dtype not in _real_numeric_dtypes:
237+
raise TypeError("Only real numeric dtypes are allowed in ceil")
236238
if x.dtype in _integer_dtypes:
237239
# Note: The return dtype of ceil is the same as the input
238240
return x
@@ -326,8 +328,8 @@ def floor(x: Array, /) -> Array:
326328
327329
See its docstring for more information.
328330
"""
329-
if x.dtype not in _numeric_dtypes:
330-
raise TypeError("Only numeric dtypes are allowed in floor")
331+
if x.dtype not in _real_numeric_dtypes:
332+
raise TypeError("Only real numeric dtypes are allowed in floor")
331333
if x.dtype in _integer_dtypes:
332334
# Note: The return dtype of floor is the same as the input
333335
return x
@@ -340,8 +342,8 @@ def floor_divide(x1: Array, x2: Array, /) -> Array:
340342
341343
See its docstring for more information.
342344
"""
343-
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
344-
raise TypeError("Only numeric dtypes are allowed in floor_divide")
345+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
346+
raise TypeError("Only real numeric dtypes are allowed in floor_divide")
345347
# Call result type here just to raise on disallowed type combinations
346348
_result_type(x1.dtype, x2.dtype)
347349
x1, x2 = Array._normalize_two_args(x1, x2)
@@ -354,8 +356,8 @@ def greater(x1: Array, x2: Array, /) -> Array:
354356
355357
See its docstring for more information.
356358
"""
357-
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
358-
raise TypeError("Only numeric dtypes are allowed in greater")
359+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
360+
raise TypeError("Only real numeric dtypes are allowed in greater")
359361
# Call result type here just to raise on disallowed type combinations
360362
_result_type(x1.dtype, x2.dtype)
361363
x1, x2 = Array._normalize_two_args(x1, x2)
@@ -368,8 +370,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
368370
369371
See its docstring for more information.
370372
"""
371-
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
372-
raise TypeError("Only numeric dtypes are allowed in greater_equal")
373+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
374+
raise TypeError("Only real numeric dtypes are allowed in greater_equal")
373375
# Call result type here just to raise on disallowed type combinations
374376
_result_type(x1.dtype, x2.dtype)
375377
x1, x2 = Array._normalize_two_args(x1, x2)
@@ -426,8 +428,8 @@ def less(x1: Array, x2: Array, /) -> Array:
426428
427429
See its docstring for more information.
428430
"""
429-
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
430-
raise TypeError("Only numeric dtypes are allowed in less")
431+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
432+
raise TypeError("Only real numeric dtypes are allowed in less")
431433
# Call result type here just to raise on disallowed type combinations
432434
_result_type(x1.dtype, x2.dtype)
433435
x1, x2 = Array._normalize_two_args(x1, x2)
@@ -440,8 +442,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array:
440442
441443
See its docstring for more information.
442444
"""
443-
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
444-
raise TypeError("Only numeric dtypes are allowed in less_equal")
445+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
446+
raise TypeError("Only real numeric dtypes are allowed in less_equal")
445447
# Call result type here just to raise on disallowed type combinations
446448
_result_type(x1.dtype, x2.dtype)
447449
x1, x2 = Array._normalize_two_args(x1, x2)
@@ -498,8 +500,8 @@ def logaddexp(x1: Array, x2: Array) -> Array:
498500
499501
See its docstring for more information.
500502
"""
501-
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
502-
raise TypeError("Only floating-point dtypes are allowed in logaddexp")
503+
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
504+
raise TypeError("Only real floating-point dtypes are allowed in logaddexp")
503505
# Call result type here just to raise on disallowed type combinations
504506
_result_type(x1.dtype, x2.dtype)
505507
x1, x2 = Array._normalize_two_args(x1, x2)
@@ -639,8 +641,8 @@ def remainder(x1: Array, x2: Array, /) -> Array:
639641
640642
See its docstring for more information.
641643
"""
642-
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
643-
raise TypeError("Only numeric dtypes are allowed in remainder")
644+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
645+
raise TypeError("Only real numeric dtypes are allowed in remainder")
644646
# Call result type here just to raise on disallowed type combinations
645647
_result_type(x1.dtype, x2.dtype)
646648
x1, x2 = Array._normalize_two_args(x1, x2)
@@ -755,8 +757,8 @@ def trunc(x: Array, /) -> Array:
755757
756758
See its docstring for more information.
757759
"""
758-
if x.dtype not in _numeric_dtypes:
759-
raise TypeError("Only numeric dtypes are allowed in trunc")
760+
if x.dtype not in _real_numeric_dtypes:
761+
raise TypeError("Only real numeric dtypes are allowed in trunc")
760762
if x.dtype in _integer_dtypes:
761763
# Note: The return dtype of trunc is the same as the input
762764
return x

array_api_strict/tests/test_elementwise_functions.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,34 +29,34 @@ def test_function_types():
2929
"asin": "floating-point",
3030
"asinh": "floating-point",
3131
"atan": "floating-point",
32-
"atan2": "floating-point",
32+
"atan2": "real floating-point",
3333
"atanh": "floating-point",
3434
"bitwise_and": "integer or boolean",
3535
"bitwise_invert": "integer or boolean",
3636
"bitwise_left_shift": "integer",
3737
"bitwise_or": "integer or boolean",
3838
"bitwise_right_shift": "integer",
3939
"bitwise_xor": "integer or boolean",
40-
"ceil": "numeric",
40+
"ceil": "real numeric",
4141
"conj": "complex floating-point",
4242
"cos": "floating-point",
4343
"cosh": "floating-point",
4444
"divide": "floating-point",
4545
"equal": "all",
4646
"exp": "floating-point",
4747
"expm1": "floating-point",
48-
"floor": "numeric",
49-
"floor_divide": "numeric",
50-
"greater": "numeric",
51-
"greater_equal": "numeric",
48+
"floor": "real numeric",
49+
"floor_divide": "real numeric",
50+
"greater": "real numeric",
51+
"greater_equal": "real numeric",
5252
"imag": "complex floating-point",
5353
"isfinite": "numeric",
5454
"isinf": "numeric",
5555
"isnan": "numeric",
56-
"less": "numeric",
57-
"less_equal": "numeric",
56+
"less": "real numeric",
57+
"less_equal": "real numeric",
5858
"log": "floating-point",
59-
"logaddexp": "floating-point",
59+
"logaddexp": "real floating-point",
6060
"log10": "floating-point",
6161
"log1p": "floating-point",
6262
"log2": "floating-point",
@@ -70,7 +70,7 @@ def test_function_types():
7070
"positive": "numeric",
7171
"pow": "numeric",
7272
"real": "complex floating-point",
73-
"remainder": "numeric",
73+
"remainder": "real numeric",
7474
"round": "numeric",
7575
"sign": "numeric",
7676
"sin": "floating-point",
@@ -80,7 +80,7 @@ def test_function_types():
8080
"subtract": "numeric",
8181
"tan": "floating-point",
8282
"tanh": "floating-point",
83-
"trunc": "numeric",
83+
"trunc": "real numeric",
8484
}
8585

8686
def _array_vals():

0 commit comments

Comments
 (0)