Skip to content

Commit c4587a4

Browse files
committed
Implement cumulative_sum (still needs to be tested)
1 parent 04c24d7 commit c4587a4

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,9 @@
290290

291291
__all__ += ["argsort", "sort"]
292292

293-
from ._statistical_functions import max, mean, min, prod, std, sum, var
293+
from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var
294294

295-
__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]
295+
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
296296

297297
from ._utility_functions import all, any
298298

array_api_strict/_statistical_functions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
)
88
from ._array_object import Array
99
from ._dtypes import float32, complex64
10+
from ._flags import requires_api_version
11+
from ._creation_functions import zeros
12+
from ._manipulation_functions import concat
1013

1114
from typing import TYPE_CHECKING
1215

@@ -16,6 +19,28 @@
1619

1720
import numpy as np
1821

22+
@requires_api_version('2023.12')
23+
def cumulative_sum(
24+
x: Array,
25+
/,
26+
*,
27+
axis: Optional[int] = None,
28+
dtype: Optional[Dtype] = None,
29+
include_initial: bool = False,
30+
) -> Array:
31+
if x.dtype not in _numeric_dtypes:
32+
raise TypeError("Only numeric dtypes are allowed in cumulative_sum")
33+
if dtype is None:
34+
dtype = x.dtype
35+
36+
if axis is None:
37+
if x.ndim > 1:
38+
raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
39+
axis = 0
40+
# np.cumsum does not support include_initial
41+
if include_initial:
42+
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis)
43+
return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype._np_dtype))
1944

2045
def max(
2146
x: Array,

0 commit comments

Comments
 (0)