From f00a882206c86369360f92ca3e550b9ca363d89c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 20:29:50 +0200 Subject: [PATCH 1/5] ENH: add count_nonzero --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_searching_functions.py | 20 +++++++++++++++++++- array_api_strict/tests/test_flags.py | 1 + 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 98b0e95..da66c9e 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -293,9 +293,9 @@ __all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"] -from ._searching_functions import argmax, argmin, nonzero, searchsorted, where +from ._searching_functions import argmax, argmin, nonzero, count_nonzero, searchsorted, where -__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"] +__all__ += ["argmax", "argmin", "nonzero", "count_nonzero", "searchsorted", "where"] from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 5460b30..df91e44 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Optional, Tuple + from typing import Literal, Optional, Tuple, Union import numpy as np @@ -45,6 +45,24 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: raise ValueError("nonzero is not allowed on 0-dimensional arrays") return tuple(Array._new(i, device=x.device) for i in np.nonzero(x._array)) + +@requires_api_version('2024.12') +def count_nonzero( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.count_nonzero <numpy.count_nonzero>` + + See its 
docstring for more information. + """ + arr = np.count_nonzero(x._array, axis=axis, keepdims=keepdims) + return Array._new(np.asarray(arr), device=x.device) + + @requires_api_version('2023.12') def searchsorted( x1: Array, diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index e0b004b..ebee415 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -307,6 +307,7 @@ def test_api_version_2023_12(func_name): 'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])), 'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)), xp.zeros((1, 4), dtype=xp.int64)), + 'count_nonzero': lambda: xp.count_nonzero(xp.arange(3)), } @pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys()) From 912362c8b60f9ecb5b37752ec60ebd4382fa3f88 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 21:48:32 +0200 Subject: [PATCH 2/5] ENH: add cumulative_prod (untested) --- array_api_strict/__init__.py | 4 +- array_api_strict/_dtypes.py | 28 +++++++++++++ array_api_strict/_statistical_functions.py | 49 ++++++++++++++++++++++ 3 files changed, 79 insertions(+), 2 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index da66c9e..27dec55 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -305,9 +305,9 @@ __all__ += ["argsort", "sort"] -from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var +from ._statistical_functions import cumulative_sum, cumulative_prod, max, mean, min, prod, std, sum, var -__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] +__all__ += ["cumulative_sum", "cumulative_prod", "max", "mean", "min", "prod", "std", "sum", "var"] from ._utility_functions import all, any, diff diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index b51ed92..cf7581d 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ 
-127,6 +127,34 @@ def __hash__(self): } +def _bit_width(dtype): + """The bit width of an integer dtype""" + if dtype == int8 or dtype == uint8: + return 8 + elif dtype == int16 or dtype == uint16: + return 16 + elif dtype == int32 or dtype == uint32: + return 32 + elif dtype == int64 or dtype == uint64: + return 64 + else: + raise ValueError(f"_bit_width: {dtype = } not understood.") + + +def _get_unsigned_from_signed(dtype): + """Return an unsigned integral dtype to match the input dtype.""" + if dtype == int8: + return uint8 + elif dtype == int16: + return uint16 + elif dtype == int32: + return uint32 + elif dtype == int64: + return uint64 + else: + raise ValueError(f"_unsigned_from_signed: {dtype = } not understood.") + + # Note: the spec defines a restricted type promotion table compared to NumPy. # In particular, cross-kind promotions like integer + float or boolean + # integer are not allowed, even for functions that accept both kinds. diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index f06785c..53f1d2f 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -5,7 +5,10 @@ _real_numeric_dtypes, _floating_dtypes, _numeric_dtypes, + _integer_dtypes ) +from . import _dtypes +from . 
import _info from ._array_object import Array from ._dtypes import float32, complex64 from ._flags import requires_api_version, get_array_api_strict_flags @@ -47,6 +50,52 @@ def cumulative_sum( x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis) return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device) + +@requires_api_version('2024.12') +def cumulative_prod( + x: Array, + /, + *, + axis: Optional[int] = None, + dtype: Optional[Dtype] = None, + include_initial: bool = False, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in cumulative_prod") + if x.ndim == 0: + raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod") + + # TODO: either all this is done by numpy's cumprod (?), or cumulative_sum should follow the same dance. + if dtype is None: + if x.dtype in _integer_dtypes: + default_int = _info.__array_namespace_info__().default_dtypes()["integral"] + if _dtypes._bit_width(x.dtype) < _dtypes._bit_width(default_int): + if x.dtype in _dtypes._unsigned_integer_dtypes: + # find the unsigned integer of the same width as `default_int` + dtype = _dtypes._get_unsigned_from_signed(default_int) + else: + dtype = default_int + else: + dtype = x.dtype + else: + dtype = x.dtype + else: + if x.dtype != dtype: + x = xp.astype(dtype) + + if axis is None: + if x.ndim > 1: + raise ValueError("axis must be specified in cumulative_prod for more than one dimension") + axis = 0 + + # np.cumprod does not support include_initial + if include_initial: + if axis < 0: + axis += x.ndim + x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis) + return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype._np_dtype), device=x.device) + + def max( x: Array, /, From 3ff4ca641b3b4a7323b9627f8f82548928c0ce4d Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 22:02:10 +0200 Subject: [PATCH 3/5] MAINT: simplify 
cumulative_prod --- array_api_strict/_dtypes.py | 28 ---------------------- array_api_strict/_statistical_functions.py | 28 ++++------------------ 2 files changed, 5 insertions(+), 51 deletions(-) diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index cf7581d..b51ed92 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -127,34 +127,6 @@ def __hash__(self): } -def _bit_width(dtype): - """The bit width of an integer dtype""" - if dtype == int8 or dtype == uint8: - return 8 - elif dtype == int16 or dtype == uint16: - return 16 - elif dtype == int32 or dtype == uint32: - return 32 - elif dtype == int64 or dtype == uint64: - return 64 - else: - raise ValueError(f"_bit_width: {dtype = } not understood.") - - -def _get_unsigned_from_signed(dtype): - """Return an unsigned integral dtype to match the input dtype.""" - if dtype == int8: - return uint8 - elif dtype == int16: - return uint16 - elif dtype == int32: - return uint32 - elif dtype == int64: - return uint64 - else: - raise ValueError(f"_unsigned_from_signed: {dtype = } not understood.") - - # Note: the spec defines a restricted type promotion table compared to NumPy. # In particular, cross-kind promotions like integer + float or boolean + # integer are not allowed, even for functions that accept both kinds. diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 53f1d2f..461ee04 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -5,14 +5,11 @@ _real_numeric_dtypes, _floating_dtypes, _numeric_dtypes, - _integer_dtypes ) -from . import _dtypes -from . 
import _info from ._array_object import Array from ._dtypes import float32, complex64 from ._flags import requires_api_version, get_array_api_strict_flags -from ._creation_functions import zeros +from ._creation_functions import zeros, ones from ._manipulation_functions import concat from typing import TYPE_CHECKING @@ -65,23 +62,8 @@ def cumulative_prod( if x.ndim == 0: raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod") - # TODO: either all this is done by numpy's cumprod (?), or cumulative_sum should follow the same dance. - if dtype is None: - if x.dtype in _integer_dtypes: - default_int = _info.__array_namespace_info__().default_dtypes()["integral"] - if _dtypes._bit_width(x.dtype) < _dtypes._bit_width(default_int): - if x.dtype in _dtypes._unsigned_integer_dtypes: - # find the unsigned integer of the same width as `default_int` - dtype = _dtypes._get_unsigned_from_signed(default_int) - else: - dtype = default_int - else: - dtype = x.dtype - else: - dtype = x.dtype - else: - if x.dtype != dtype: - x = xp.astype(dtype) + if dtype is not None: + dtype = dtype._np_dtype if axis is None: if x.ndim > 1: @@ -92,8 +74,8 @@ def cumulative_prod( if include_initial: if axis < 0: axis += x.ndim - x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis) - return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype._np_dtype), device=x.device) + x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis) + return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype), device=x.device) def max( From 83ac04f1bff0750305543b96edbe05575c0a6e36 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 22:04:00 +0200 Subject: [PATCH 4/5] BUG: fix dtype of include_initial in cumulative_sum In `concat([zeros(...), x])` zeros must have the same dtype as `x`. 
--- array_api_strict/_statistical_functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 461ee04..e41e7ef 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -31,7 +31,6 @@ def cumulative_sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in cumulative_sum") - dt = x.dtype if dtype is None else dtype if dtype is not None: dtype = dtype._np_dtype @@ -44,7 +43,7 @@ def cumulative_sum( if include_initial: if axis < 0: axis += x.ndim - x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis) + x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis) return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device) From 61bf3c1e790685a3c0de2ef27e9e105623f351f1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 22:07:15 +0200 Subject: [PATCH 5/5] TST: add cumulative_prod to test_flags --- array_api_strict/tests/test_flags.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index ebee415..dcfc20d 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -308,6 +308,7 @@ def test_api_version_2023_12(func_name): 'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)), xp.zeros((1, 4), dtype=xp.int64)), 'count_nonzero': lambda: xp.count_nonzero(xp.arange(3)), + 'cumulative_prod': lambda: xp.cumulative_prod(xp.arange(1, 5)), } @pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys())