Skip to content

Commit 31ceaae

Browse files
committed
Add preliminary diff() function for 2024.12
1 parent 728c69a commit 31ceaae

File tree

3 files changed

+51
-2
lines changed

3 files changed

+51
-2
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,9 @@
305305

306306
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
307307

308-
from ._utility_functions import all, any
308+
from ._utility_functions import all, any, diff
309309

310-
__all__ += ["all", "any"]
310+
__all__ += ["all", "any", "diff"]
311311

312312
from ._array_object import Device
313313
__all__ += ["Device"]

array_api_strict/_utility_functions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from ._array_object import Array
4+
from ._flags import requires_api_version
45

56
from typing import TYPE_CHECKING
67
if TYPE_CHECKING:
@@ -37,3 +38,25 @@ def any(
3738
See its docstring for more information.
3839
"""
3940
return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)), device=x.device)
41+
42+
@requires_api_version('2024.12')
43+
def diff(
44+
x: Array,
45+
/,
46+
*,
47+
axis: int = -1,
48+
n: int = 1,
49+
prepend: Optional[Array] = None,
50+
append: Optional[Array] = None,
51+
) -> Array:
52+
# NumPy does not support prepend=None or append=None
53+
kwargs = dict(axis=axis, n=n)
54+
if prepend is not None:
55+
if prepend.device != x.device:
56+
raise ValueError(f"Arrays from two different devices ({prepend.device} and {x.device}) can not be combined.")
57+
kwargs['prepend'] = prepend._array
58+
if append is not None:
59+
if append.device != x.device:
60+
raise ValueError(f"Arrays from two different devices ({append.device} and {x.device}) can not be combined.")
61+
kwargs['append'] = append._array
62+
return Array._new(np.diff(x._array, **kwargs), device=x.device)

array_api_strict/tests/test_flags.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ def test_fft(func_name):
282282
'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0),
283283
}
284284

285+
api_version_2024_12_examples = {
286+
'diff': lambda: xp.diff(xp.asarray([0, 1, 2])),
287+
}
288+
285289
@pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys())
286290
def test_api_version_2023_12(func_name):
287291
func = api_version_2023_12_examples[func_name]
@@ -300,6 +304,28 @@ def test_api_version_2023_12(func_name):
300304
set_array_api_strict_flags(api_version='2022.12')
301305
pytest.raises(RuntimeError, func)
302306

307+
@pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys())
308+
def test_api_version_2024_12(func_name):
309+
func = api_version_2024_12_examples[func_name]
310+
311+
# By default, these functions should error
312+
pytest.raises(RuntimeError, func)
313+
314+
# In 2022.12 and 2023.12, these functions should error
315+
set_array_api_strict_flags(api_version='2022.12')
316+
pytest.raises(RuntimeError, func)
317+
set_array_api_strict_flags(api_version='2023.12')
318+
pytest.raises(RuntimeError, func)
319+
320+
# They should not error in 2024.12
321+
with pytest.warns(UserWarning):
322+
set_array_api_strict_flags(api_version='2024.12')
323+
func()
324+
325+
# Test the behavior gets updated properly
326+
set_array_api_strict_flags(api_version='2023.12')
327+
pytest.raises(RuntimeError, func)
328+
303329
def test_disabled_extensions():
304330
# Test that xp.extension errors when an extension is disabled, and that
305331
# xp.__all__ is updated properly.

0 commit comments

Comments
 (0)