Skip to content

Commit 1d111b3

Browse files
committed
Add draft implementation for nextafter
1 parent b2e3ecc commit 1d111b3

File tree

4 files changed

+24
-2
lines changed

4 files changed

+24
-2
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@
172172
minimum,
173173
multiply,
174174
negative,
175+
nextafter,
175176
not_equal,
176177
positive,
177178
pow,
@@ -240,6 +241,7 @@
240241
"minimum",
241242
"multiply",
242243
"negative",
244+
"nextafter",
243245
"not_equal",
244246
"positive",
245247
"pow",

array_api_strict/_elementwise_functions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,20 @@ def negative(x: Array, /) -> Array:
805805
return Array._new(np.negative(x._array), device=x.device)
806806

807807

808+
@requires_api_version('2024.12')
809+
def nextafter(x1: Array, x2: Array, /) -> Array:
810+
"""
811+
Array API compatible wrapper for :py:func:`np.nextafter <numpy.nextafter>`.
812+
813+
See its docstring for more information.
814+
"""
815+
if x1.device != x2.device:
816+
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
817+
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
818+
raise TypeError("Only real floating-point dtypes are allowed in nextafter")
819+
x1, x2 = Array._normalize_two_args(x1, x2)
820+
return Array._new(np.nextafter(x1._array, x2._array), device=x1.device)
821+
808822
def not_equal(x1: Array, x2: Array, /) -> Array:
809823
"""
810824
Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.

array_api_strict/tests/test_elementwise_functions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from numpy.testing import assert_raises
44

5+
import pytest
6+
57
from .. import asarray, _elementwise_functions
68
from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
79
from .._dtypes import (
@@ -79,6 +81,7 @@ def nargs(func):
7981
"minimum": "real numeric",
8082
"multiply": "numeric",
8183
"negative": "numeric",
84+
"nextafter": "real floating-point",
8285
"not_equal": "all",
8386
"positive": "numeric",
8487
"pow": "numeric",
@@ -126,7 +129,8 @@ def _array_vals(dtypes):
126129
yield asarray(1., dtype=d)
127130

128131
# Use the latest version of the standard so all functions are included
129-
set_array_api_strict_flags(api_version="2023.12")
132+
with pytest.warns(UserWarning):
133+
set_array_api_strict_flags(api_version="2024.12")
130134

131135
for func_name, types in elementwise_function_input_types.items():
132136
dtypes = _dtype_categories[types]
@@ -162,7 +166,8 @@ def _array_vals():
162166
yield asarray(1.0, dtype=d)
163167

164168
# Use the latest version of the standard so all functions are included
165-
set_array_api_strict_flags(api_version="2023.12")
169+
with pytest.warns(UserWarning):
170+
set_array_api_strict_flags(api_version="2024.12")
166171

167172
for x in _array_vals():
168173
for func_name, types in elementwise_function_input_types.items():

array_api_strict/tests/test_flags.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def test_fft(func_name):
284284

285285
api_version_2024_12_examples = {
286286
'diff': lambda: xp.diff(xp.asarray([0, 1, 2])),
287+
'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)),
287288
}
288289

289290
@pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys())

0 commit comments

Comments
 (0)