Skip to content

Commit 3fde6a1

Browse files
committed
Add conj, imag, and real functions to numpy.array_api
Original NumPy Commit: 103bca57407ba69632c005609c58538c3a765123
1 parent a54eeb7 commit 3fde6a1

File tree

4 files changed

+41
-0
lines changed

4 files changed

+41
-0
lines changed

array_api_strict/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@
234234
bitwise_right_shift,
235235
bitwise_xor,
236236
ceil,
237+
conj,
237238
cos,
238239
cosh,
239240
divide,
@@ -244,6 +245,7 @@
244245
floor_divide,
245246
greater,
246247
greater_equal,
248+
imag,
247249
isfinite,
248250
isinf,
249251
isnan,
@@ -263,6 +265,7 @@
263265
not_equal,
264266
positive,
265267
pow,
268+
real,
266269
remainder,
267270
round,
268271
sign,

array_api_strict/_dtypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
"integer or boolean": _integer_or_boolean_dtypes,
8484
"boolean": _boolean_dtypes,
8585
"real floating-point": _floating_dtypes,
86+
"complex floating-point": _complex_floating_dtypes,
8687
"floating-point": _floating_dtypes,
8788
}
8889

array_api_strict/_elementwise_functions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._dtypes import (
44
_boolean_dtypes,
55
_floating_dtypes,
6+
_complex_floating_dtypes,
67
_integer_dtypes,
78
_integer_or_boolean_dtypes,
89
_numeric_dtypes,
@@ -238,6 +239,17 @@ def ceil(x: Array, /) -> Array:
238239
return Array._new(np.ceil(x._array))
239240

240241

242+
def conj(x: Array, /) -> Array:
243+
"""
244+
Array API compatible wrapper for :py:func:`np.conj <numpy.conj>`.
245+
246+
See its docstring for more information.
247+
"""
248+
if x.dtype not in _complex_floating_dtypes:
249+
raise TypeError("Only complex floating-point dtypes are allowed in conj")
250+
return Array._new(np.conj(x))
251+
252+
241253
def cos(x: Array, /) -> Array:
242254
"""
243255
Array API compatible wrapper for :py:func:`np.cos <numpy.cos>`.
@@ -364,6 +376,17 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
364376
return Array._new(np.greater_equal(x1._array, x2._array))
365377

366378

379+
def imag(x: Array, /) -> Array:
380+
"""
381+
Array API compatible wrapper for :py:func:`np.imag <numpy.imag>`.
382+
383+
See its docstring for more information.
384+
"""
385+
if x.dtype not in _complex_floating_dtypes:
386+
raise TypeError("Only complex floating-point dtypes are allowed in imag")
387+
return Array._new(np.imag(x))
388+
389+
367390
def isfinite(x: Array, /) -> Array:
368391
"""
369392
Array API compatible wrapper for :py:func:`np.isfinite <numpy.isfinite>`.
@@ -599,6 +622,17 @@ def pow(x1: Array, x2: Array, /) -> Array:
599622
return Array._new(np.power(x1._array, x2._array))
600623

601624

625+
def real(x: Array, /) -> Array:
626+
"""
627+
Array API compatible wrapper for :py:func:`np.real <numpy.real>`.
628+
629+
See its docstring for more information.
630+
"""
631+
if x.dtype not in _complex_floating_dtypes:
632+
raise TypeError("Only complex floating-point dtypes are allowed in real")
633+
return Array._new(np.real(x))
634+
635+
602636
def remainder(x1: Array, x2: Array, /) -> Array:
603637
"""
604638
Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.

array_api_strict/tests/test_elementwise_functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_function_types():
3838
"bitwise_right_shift": "integer",
3939
"bitwise_xor": "integer or boolean",
4040
"ceil": "numeric",
41+
"conj": "complex floating-point",
4142
"cos": "floating-point",
4243
"cosh": "floating-point",
4344
"divide": "floating-point",
@@ -48,6 +49,7 @@ def test_function_types():
4849
"floor_divide": "numeric",
4950
"greater": "numeric",
5051
"greater_equal": "numeric",
52+
"imag": "complex floating-point",
5153
"isfinite": "numeric",
5254
"isinf": "numeric",
5355
"isnan": "numeric",
@@ -67,6 +69,7 @@ def test_function_types():
6769
"not_equal": "all",
6870
"positive": "numeric",
6971
"pow": "numeric",
72+
"real": "complex floating-point",
7073
"remainder": "numeric",
7174
"round": "numeric",
7275
"sign": "numeric",

0 commit comments

Comments
 (0)