diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 98b0e95..27dec55 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -293,9 +293,9 @@ __all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"] -from ._searching_functions import argmax, argmin, nonzero, searchsorted, where +from ._searching_functions import argmax, argmin, nonzero, count_nonzero, searchsorted, where -__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"] +__all__ += ["argmax", "argmin", "nonzero", "count_nonzero", "searchsorted", "where"] from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values @@ -305,9 +305,9 @@ __all__ += ["argsort", "sort"] -from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var +from ._statistical_functions import cumulative_sum, cumulative_prod, max, mean, min, prod, std, sum, var -__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] +__all__ += ["cumulative_sum", "cumulative_prod", "max", "mean", "min", "prod", "std", "sum", "var"] from ._utility_functions import all, any, diff diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 5460b30..df91e44 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Optional, Tuple + from typing import Literal, Optional, Tuple, Union import numpy as np @@ -45,6 +45,24 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: raise ValueError("nonzero is not allowed on 0-dimensional arrays") return tuple(Array._new(i, device=x.device) for i in np.nonzero(x._array)) + +@requires_api_version('2024.12') +def count_nonzero( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.count_nonzero ` + + See its docstring for more information. + """ + arr = np.count_nonzero(x._array, axis=axis, keepdims=keepdims) + return Array._new(np.asarray(arr), device=x.device) + + @requires_api_version('2023.12') def searchsorted( x1: Array, diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index f06785c..e41e7ef 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -9,7 +9,7 @@ from ._array_object import Array from ._dtypes import float32, complex64 from ._flags import requires_api_version, get_array_api_strict_flags -from ._creation_functions import zeros +from ._creation_functions import zeros, ones from ._manipulation_functions import concat from typing import TYPE_CHECKING @@ -31,7 +31,6 @@ def cumulative_sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in cumulative_sum") - dt = x.dtype if dtype is None else dtype if dtype is not None: dtype = dtype._np_dtype @@ -44,9 +43,40 @@ def cumulative_sum( if include_initial: if axis < 0: axis += x.ndim - x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis) + x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis) return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device) + +@requires_api_version('2024.12') +def cumulative_prod( + x: Array, + /, + *, + axis: Optional[int] = None, + dtype: Optional[Dtype] = None, + include_initial: bool = False, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in cumulative_prod") + if x.ndim == 0: + raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod") + + if dtype is not None: + dtype = dtype._np_dtype + + if axis is None: + if x.ndim > 1: + raise ValueError("axis must be specified in cumulative_prod for more than one dimension") + axis = 0 + + # np.cumprod does not support include_initial + if include_initial: + if axis < 0: + axis += x.ndim + x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis) + return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype), device=x.device) + + def max( x: Array, /, diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index e0b004b..dcfc20d 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -307,6 +307,8 @@ def test_api_version_2023_12(func_name): 'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])), 'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)), xp.zeros((1, 4), dtype=xp.int64)), + 'count_nonzero': lambda: xp.count_nonzero(xp.arange(3)), + 'cumulative_prod': lambda: xp.cumulative_prod(xp.arange(1, 5)), } @pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys())