Skip to content

Commit a54eeb7

Browse files
committed
Update numpy.array_api magic methods for complex numbers
Updates from the v2022.12 version of the spec: - Add __complex__. - __float__, __int__, and __bool__ are now more lenient in what dtypes they can operate on. - Support complex scalars and dtypes in all operators (except those that should not operate on complex numbers). - Disallow integer scalars that are out of the bounds of the array dtype. - Update the tests accordingly. Original NumPy Commit: 8b63fc295bafea3efd3d115964200dbaac7be8a5
1 parent 0075777 commit a54eeb7

File tree

3 files changed

+99
-44
lines changed

3 files changed

+99
-44
lines changed

array_api_strict/_array_object.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_integer_dtypes,
2525
_integer_or_boolean_dtypes,
2626
_floating_dtypes,
27+
_complex_floating_dtypes,
2728
_numeric_dtypes,
2829
_result_type,
2930
_dtype_categories,
@@ -139,7 +140,7 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor
139140

140141
if self.dtype not in _dtype_categories[dtype_category]:
141142
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
142-
if isinstance(other, (int, float, bool)):
143+
if isinstance(other, (int, complex, float, bool)):
143144
other = self._promote_scalar(other)
144145
elif isinstance(other, Array):
145146
if other.dtype not in _dtype_categories[dtype_category]:
@@ -189,11 +190,23 @@ def _promote_scalar(self, scalar):
189190
raise TypeError(
190191
"Python int scalars cannot be promoted with bool arrays"
191192
)
193+
if self.dtype in _integer_dtypes:
194+
info = np.iinfo(self.dtype)
195+
if not (info.min <= scalar <= info.max):
196+
raise OverflowError(
197+
"Python int scalars must be within the bounds of the dtype for integer arrays"
198+
)
199+
# int + array(floating) is allowed
192200
elif isinstance(scalar, float):
193201
if self.dtype not in _floating_dtypes:
194202
raise TypeError(
195203
"Python float scalars can only be promoted with floating-point arrays."
196204
)
205+
elif isinstance(scalar, complex):
206+
if self.dtype not in _complex_floating_dtypes:
207+
raise TypeError(
208+
"Python complex scalars can only be promoted with complex floating-point arrays."
209+
)
197210
else:
198211
raise TypeError("'scalar' must be a Python scalar")
199212

@@ -454,11 +467,19 @@ def __bool__(self: Array, /) -> bool:
454467
# Note: This is an error here.
455468
if self._array.ndim != 0:
456469
raise TypeError("bool is only allowed on arrays with 0 dimensions")
457-
if self.dtype not in _boolean_dtypes:
458-
raise ValueError("bool is only allowed on boolean arrays")
459470
res = self._array.__bool__()
460471
return res
461472

473+
def __complex__(self: Array, /) -> float:
474+
"""
475+
Performs the operation __complex__.
476+
"""
477+
# Note: This is an error here.
478+
if self._array.ndim != 0:
479+
raise TypeError("complex is only allowed on arrays with 0 dimensions")
480+
res = self._array.__complex__()
481+
return res
482+
462483
def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
463484
"""
464485
Performs the operation __dlpack__.
@@ -492,16 +513,16 @@ def __float__(self: Array, /) -> float:
492513
# Note: This is an error here.
493514
if self._array.ndim != 0:
494515
raise TypeError("float is only allowed on arrays with 0 dimensions")
495-
if self.dtype not in _floating_dtypes:
496-
raise ValueError("float is only allowed on floating-point arrays")
516+
if self.dtype in _complex_floating_dtypes:
517+
raise TypeError("float is not allowed on complex floating-point arrays")
497518
res = self._array.__float__()
498519
return res
499520

500521
def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
501522
"""
502523
Performs the operation __floordiv__.
503524
"""
504-
other = self._check_allowed_dtypes(other, "numeric", "__floordiv__")
525+
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
505526
if other is NotImplemented:
506527
return other
507528
self, other = self._normalize_two_args(self, other)
@@ -512,7 +533,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
512533
"""
513534
Performs the operation __ge__.
514535
"""
515-
other = self._check_allowed_dtypes(other, "numeric", "__ge__")
536+
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
516537
if other is NotImplemented:
517538
return other
518539
self, other = self._normalize_two_args(self, other)
@@ -542,7 +563,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
542563
"""
543564
Performs the operation __gt__.
544565
"""
545-
other = self._check_allowed_dtypes(other, "numeric", "__gt__")
566+
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
546567
if other is NotImplemented:
547568
return other
548569
self, other = self._normalize_two_args(self, other)
@@ -556,8 +577,8 @@ def __int__(self: Array, /) -> int:
556577
# Note: This is an error here.
557578
if self._array.ndim != 0:
558579
raise TypeError("int is only allowed on arrays with 0 dimensions")
559-
if self.dtype not in _integer_dtypes:
560-
raise ValueError("int is only allowed on integer arrays")
580+
if self.dtype in _complex_floating_dtypes:
581+
raise TypeError("int is not allowed on complex floating-point arrays")
561582
res = self._array.__int__()
562583
return res
563584

@@ -581,7 +602,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
581602
"""
582603
Performs the operation __le__.
583604
"""
584-
other = self._check_allowed_dtypes(other, "numeric", "__le__")
605+
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
585606
if other is NotImplemented:
586607
return other
587608
self, other = self._normalize_two_args(self, other)
@@ -603,7 +624,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
603624
"""
604625
Performs the operation __lt__.
605626
"""
606-
other = self._check_allowed_dtypes(other, "numeric", "__lt__")
627+
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
607628
if other is NotImplemented:
608629
return other
609630
self, other = self._normalize_two_args(self, other)
@@ -626,7 +647,7 @@ def __mod__(self: Array, other: Union[int, float, Array], /) -> Array:
626647
"""
627648
Performs the operation __mod__.
628649
"""
629-
other = self._check_allowed_dtypes(other, "numeric", "__mod__")
650+
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
630651
if other is NotImplemented:
631652
return other
632653
self, other = self._normalize_two_args(self, other)
@@ -808,7 +829,7 @@ def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
808829
"""
809830
Performs the operation __ifloordiv__.
810831
"""
811-
other = self._check_allowed_dtypes(other, "numeric", "__ifloordiv__")
832+
other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__")
812833
if other is NotImplemented:
813834
return other
814835
self._array.__ifloordiv__(other._array)
@@ -818,7 +839,7 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
818839
"""
819840
Performs the operation __rfloordiv__.
820841
"""
821-
other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__")
842+
other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__")
822843
if other is NotImplemented:
823844
return other
824845
self, other = self._normalize_two_args(self, other)
@@ -874,7 +895,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array:
874895
"""
875896
Performs the operation __imod__.
876897
"""
877-
other = self._check_allowed_dtypes(other, "numeric", "__imod__")
898+
other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
878899
if other is NotImplemented:
879900
return other
880901
self._array.__imod__(other._array)
@@ -884,7 +905,7 @@ def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array:
884905
"""
885906
Performs the operation __rmod__.
886907
"""
887-
other = self._check_allowed_dtypes(other, "numeric", "__rmod__")
908+
other = self._check_allowed_dtypes(other, "real numeric", "__rmod__")
888909
if other is NotImplemented:
889910
return other
890911
self, other = self._normalize_two_args(self, other)

array_api_strict/_dtypes.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_boolean_dtypes = (bool,)
3636
_real_floating_dtypes = (float32, float64)
3737
_floating_dtypes = (float32, float64, complex64, complex128)
38+
_complex_floating_dtypes = (complex64, complex128)
3839
_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64)
3940
_integer_or_boolean_dtypes = (
4041
bool,
@@ -47,6 +48,18 @@
4748
uint32,
4849
uint64,
4950
)
51+
_real_numeric_dtypes = (
52+
float32,
53+
float64,
54+
int8,
55+
int16,
56+
int32,
57+
int64,
58+
uint8,
59+
uint16,
60+
uint32,
61+
uint64,
62+
)
5063
_numeric_dtypes = (
5164
float32,
5265
float64,
@@ -64,6 +77,7 @@
6477

6578
_dtype_categories = {
6679
"all": _all_dtypes,
80+
"real numeric": _real_numeric_dtypes,
6781
"numeric": _numeric_dtypes,
6882
"integer": _integer_dtypes,
6983
"integer or boolean": _integer_or_boolean_dtypes,
@@ -144,7 +158,7 @@
144158
(complex64, complex64): complex64,
145159
(complex64, complex128): complex128,
146160
(complex128, complex64): complex128,
147-
(complex128, complex64): complex128,
161+
(complex128, complex128): complex128,
148162
(float32, complex64): complex64,
149163
(float32, complex128): complex128,
150164
(float64, complex64): complex128,

array_api_strict/tests/test_array_object.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import operator
22

3-
from numpy.testing import assert_raises
3+
from numpy.testing import assert_raises, suppress_warnings
44
import numpy as np
55
import pytest
66

@@ -9,9 +9,12 @@
99
from numpy._dtypes import (
1010
_all_dtypes,
1111
_boolean_dtypes,
12+
_real_floating_dtypes,
1213
_floating_dtypes,
14+
_complex_floating_dtypes,
1315
_integer_dtypes,
1416
_integer_or_boolean_dtypes,
17+
_real_numeric_dtypes,
1518
_numeric_dtypes,
1619
int8,
1720
int16,
@@ -85,13 +88,13 @@ def test_operators():
8588
"__add__": "numeric",
8689
"__and__": "integer_or_boolean",
8790
"__eq__": "all",
88-
"__floordiv__": "numeric",
89-
"__ge__": "numeric",
90-
"__gt__": "numeric",
91-
"__le__": "numeric",
91+
"__floordiv__": "real numeric",
92+
"__ge__": "real numeric",
93+
"__gt__": "real numeric",
94+
"__le__": "real numeric",
9295
"__lshift__": "integer",
93-
"__lt__": "numeric",
94-
"__mod__": "numeric",
96+
"__lt__": "real numeric",
97+
"__mod__": "real numeric",
9598
"__mul__": "numeric",
9699
"__ne__": "all",
97100
"__or__": "integer_or_boolean",
@@ -101,7 +104,6 @@ def test_operators():
101104
"__truediv__": "floating",
102105
"__xor__": "integer_or_boolean",
103106
}
104-
105107
# Recompute each time because of in-place ops
106108
def _array_vals():
107109
for d in _integer_dtypes:
@@ -111,27 +113,28 @@ def _array_vals():
111113
for d in _floating_dtypes:
112114
yield asarray(1.0, dtype=d)
113115

116+
117+
BIG_INT = int(1e30)
114118
for op, dtypes in binary_op_dtypes.items():
115119
ops = [op]
116120
if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
117121
rop = "__r" + op[2:]
118122
iop = "__i" + op[2:]
119123
ops += [rop, iop]
120-
for s in [1, 1.0, False]:
124+
for s in [1, 1.0, 1j, BIG_INT, False]:
121125
for _op in ops:
122126
for a in _array_vals():
123127
# Test array op scalar. From the spec, the following combinations
124128
# are supported:
125129

126130
# - Python bool for a bool array dtype,
127131
# - a Python int within the bounds of the given dtype for integer array dtypes,
128-
# - a Python int or float for floating-point array dtypes
129-
130-
# We do not do bounds checking for int scalars, but rather use the default
131-
# NumPy behavior for casting in that case.
132+
# - a Python int or float for real floating-point array dtypes
133+
# - a Python int, float, or complex for complex floating-point array dtypes
132134

133135
if ((dtypes == "all"
134136
or dtypes == "numeric" and a.dtype in _numeric_dtypes
137+
or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes
135138
or dtypes == "integer" and a.dtype in _integer_dtypes
136139
or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes
137140
or dtypes == "boolean" and a.dtype in _boolean_dtypes
@@ -141,10 +144,18 @@ def _array_vals():
141144
# isinstance here.
142145
and (a.dtype in _boolean_dtypes and type(s) == bool
143146
or a.dtype in _integer_dtypes and type(s) == int
144-
or a.dtype in _floating_dtypes and type(s) in [float, int]
147+
or a.dtype in _real_floating_dtypes and type(s) in [float, int]
148+
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
145149
)):
146-
# Only test for no error
147-
getattr(a, _op)(s)
150+
if a.dtype in _integer_dtypes and s == BIG_INT:
151+
assert_raises(OverflowError, lambda: getattr(a, _op)(s))
152+
else:
153+
# Only test for no error
154+
with suppress_warnings() as sup:
155+
# ignore warnings from pow(BIG_INT)
156+
sup.filter(RuntimeWarning,
157+
"invalid value encountered in power")
158+
getattr(a, _op)(s)
148159
else:
149160
assert_raises(TypeError, lambda: getattr(a, _op)(s))
150161

@@ -174,8 +185,9 @@ def _array_vals():
174185
# Ensure only those dtypes that are required for every operator are allowed.
175186
elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
176187
or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
188+
or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes)
177189
or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
178-
or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _numeric_dtypes
190+
or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
179191
or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
180192
or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes)
181193
or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
@@ -263,31 +275,39 @@ def test_python_scalar_construtors():
263275
b = asarray(False)
264276
i = asarray(0)
265277
f = asarray(0.0)
278+
c = asarray(0j)
266279

267280
assert bool(b) == False
268281
assert int(i) == 0
269282
assert float(f) == 0.0
270283
assert operator.index(i) == 0
271284

272-
# bool/int/float should only be allowed on 0-D arrays.
285+
# bool/int/float/complex should only be allowed on 0-D arrays.
273286
assert_raises(TypeError, lambda: bool(asarray([False])))
274287
assert_raises(TypeError, lambda: int(asarray([0])))
275288
assert_raises(TypeError, lambda: float(asarray([0.0])))
289+
assert_raises(TypeError, lambda: complex(asarray([0j])))
276290
assert_raises(TypeError, lambda: operator.index(asarray([0])))
277291

278-
# bool/int/float should only be allowed on arrays of the corresponding
279-
# dtype
280-
assert_raises(ValueError, lambda: bool(i))
281-
assert_raises(ValueError, lambda: bool(f))
292+
# bool should work on all types of arrays
293+
assert bool(b) is bool(i) is bool(f) is bool(c) is False
294+
295+
# int should fail on complex arrays
296+
assert int(b) == int(i) == int(f) == 0
297+
assert_raises(TypeError, lambda: int(c))
282298

283-
assert_raises(ValueError, lambda: int(b))
284-
assert_raises(ValueError, lambda: int(f))
299+
# float should fail on complex arrays
300+
assert float(b) == float(i) == float(f) == 0.0
301+
assert_raises(TypeError, lambda: float(c))
285302

286-
assert_raises(ValueError, lambda: float(b))
287-
assert_raises(ValueError, lambda: float(i))
303+
# complex should work on all types of arrays
304+
assert complex(b) == complex(i) == complex(f) == complex(c) == 0j
288305

306+
# index should only work on integer arrays
307+
assert operator.index(i) == 0
289308
assert_raises(TypeError, lambda: operator.index(b))
290309
assert_raises(TypeError, lambda: operator.index(f))
310+
assert_raises(TypeError, lambda: operator.index(c))
291311

292312

293313
def test_device_property():

0 commit comments

Comments
 (0)