Skip to content

Commit 250ba86

Browse files
committed
Add maximum and minimum
1 parent 3e2d46d commit 250ba86

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

array_api_strict/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@
166166
logical_not,
167167
logical_or,
168168
logical_xor,
169+
maximum,
170+
minimum,
169171
multiply,
170172
negative,
171173
not_equal,
@@ -231,6 +233,8 @@
231233
"logical_not",
232234
"logical_or",
233235
"logical_xor",
236+
"maximum",
237+
"minimum",
234238
"multiply",
235239
"negative",
236240
"not_equal",

array_api_strict/_elementwise_functions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,35 @@ def logical_xor(x1: Array, x2: Array, /) -> Array:
651651
x1, x2 = Array._normalize_two_args(x1, x2)
652652
return Array._new(np.logical_xor(x1._array, x2._array))
653653

654+
@requires_api_version('2023.12')
655+
def maximum(x1: Array, x2: Array, /) -> Array:
656+
"""
657+
Array API compatible wrapper for :py:func:`np.maximum <numpy.maximum>`.
658+
659+
See its docstring for more information.
660+
"""
661+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
662+
raise TypeError("Only real numeric dtypes are allowed in maximum")
663+
# Call result type here just to raise on disallowed type combinations
664+
_result_type(x1.dtype, x2.dtype)
665+
x1, x2 = Array._normalize_two_args(x1, x2)
666+
# TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error
667+
# in that case?
668+
return Array._new(np.maximum(x1._array, x2._array))
669+
670+
@requires_api_version('2023.12')
671+
def minimum(x1: Array, x2: Array, /) -> Array:
672+
"""
673+
Array API compatible wrapper for :py:func:`np.minimum <numpy.minimum>`.
674+
675+
See its docstring for more information.
676+
"""
677+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
678+
raise TypeError("Only real numeric dtypes are allowed in minimum")
679+
# Call result type here just to raise on disallowed type combinations
680+
_result_type(x1.dtype, x2.dtype)
681+
x1, x2 = Array._normalize_two_args(x1, x2)
682+
return Array._new(np.minimum(x1._array, x2._array))
654683

655684
def multiply(x1: Array, x2: Array, /) -> Array:
656685
"""

array_api_strict/tests/test_elementwise_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def nargs(func):
6464
"logical_not": "boolean",
6565
"logical_or": "boolean",
6666
"logical_xor": "boolean",
67+
"maximum": "real numeric",
68+
"minimum": "real numeric",
6769
"multiply": "numeric",
6870
"negative": "numeric",
6971
"not_equal": "all",

0 commit comments

Comments
 (0)