Skip to content

Commit 1bfd146

Browse files
committed
Add new cumulative_sum function to numpy and array_api namespaces
1 parent 2c85ca6 commit 1bfd146

File tree

5 files changed

+31
-0
lines changed

5 files changed

+31
-0
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5567,3 +5567,27 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
55675567
else:
55685568
raise ValueError(f"mode should be one of 'wrap' or 'clip'; got {mode=}")
55695569
return arr.at[unravel_index(ind_arr, arr.shape)].set(v_arr, mode=scatter_mode)
5570+
5571+
5572+
@util.implements(getattr(np, 'cumulative_sum', None))
5573+
def cumulative_sum(
5574+
x: Array, /, *, axis: int | None = None,
5575+
dtype: DTypeLike | None = None,
5576+
include_initial: bool = False) -> Array:
5577+
if axis is None and x.ndim > 1:
5578+
raise ValueError(
5579+
f"The input array has rank {x.ndim}, however axis was not set to an "
5580+
"explicit value. The axis argument is only optional for one-dimensional "
5581+
"arrays.")
5582+
util.check_arraylike("cumulative_sum", x)
5583+
dtypes.check_user_dtype_supported(dtype)
5584+
kind = x.dtype.kind
5585+
default_dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
5586+
if (dtype is None and kind in {'i', 'u'}
5587+
and x.dtype.itemsize < default_dtype.itemsize):
5588+
dtype = default_dtype
5589+
5590+
out = reductions.cumsum(x, axis=axis, dtype=dtype)
5591+
zeros_shape = list(x.shape)
5592+
zeros_shape[axis if axis else 0] = 1
5593+
return append(zeros(zeros_shape, dtype=out.dtype), out, axis=axis) if include_initial else out

jax/experimental/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@
195195
)
196196

197197
from jax.experimental.array_api._statistical_functions import (
198+
cumulative_sum as cumulative_sum,
198199
max as max,
199200
mean as mean,
200201
min as min,

jax/experimental/array_api/_statistical_functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
)
1919

2020

21+
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
22+
"""Calculates the cumulative sum of elements in the input array x."""
23+
return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial)
24+
2125
def max(x, /, *, axis=None, keepdims=False):
2226
"""Calculates the maximum value of the input array x."""
2327
return jax.numpy.max(x, axis=axis, keepdims=keepdims)

jax/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
cov as cov,
7979
cross as cross,
8080
csingle as csingle,
81+
cumulative_sum as cumulative_sum,
8182
delete as delete,
8283
diag as diag,
8384
diagflat as diagflat,

tests/array_api_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
'conj',
6868
'cos',
6969
'cosh',
70+
'cumulative_sum',
7071
'divide',
7172
'e',
7273
'empty',

0 commit comments

Comments
 (0)