Skip to content

Commit f5fa5c1

Browse files
committed
Add new cumulative_sum function to numpy and array_api namespaces
1 parent 2c85ca6 commit f5fa5c1

File tree

7 files changed

+95
-0
lines changed

7 files changed

+95
-0
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
88

99
## jax 0.4.27
1010

11+
* New Functionality
12+
* Added {func}`jax.numpy.cumulative_sum`, following the addition of this
13+
function in the array API 2023 standard, soon to be adopted by NumPy.
14+
1115
* Changes
1216
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
1317
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover

jax/_src/numpy/lax_numpy.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5567,3 +5567,34 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
55675567
else:
55685568
raise ValueError(f"mode should be one of 'wrap' or 'clip'; got {mode=}")
55695569
return arr.at[unravel_index(ind_arr, arr.shape)].set(v_arr, mode=scatter_mode)
5570+
5571+
5572+
@util.implements(getattr(np, 'cumulative_sum', None))
5573+
def cumulative_sum(
5574+
x: Array, /, *, axis: int | None = None,
5575+
dtype: DTypeLike | None = None,
5576+
include_initial: bool = False) -> Array:
5577+
if isscalar(x) or x.ndim == 0:
5578+
raise ValueError(
5579+
"The input must be non-scalar to take a cumulative sum, however a "
5580+
"scalar value or scalar array was given."
5581+
)
5582+
if axis is None and x.ndim > 1:
5583+
raise ValueError(
5584+
f"The input array has rank {x.ndim}, however axis was not set to an "
5585+
"explicit value. The axis argument is only optional for one-dimensional "
5586+
"arrays.")
5587+
axis = axis or 0
5588+
util.check_arraylike("cumulative_sum", x)
5589+
dtypes.check_user_dtype_supported(dtype)
5590+
kind = x.dtype.kind
5591+
if (dtype is None and kind in {'i', 'u'}
5592+
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
5593+
dtype = dtypes.dtype(dtypes._default_types[kind])
5594+
5595+
out = reductions.cumsum(x, axis=axis, dtype=dtype)
5596+
zeros_shape = list(x.shape)
5597+
zeros_shape[axis] = 1
5598+
if include_initial:
5599+
out = concat([zeros(zeros_shape, dtype=out.dtype), out], axis=axis)
5600+
return out

jax/experimental/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@
195195
)
196196

197197
from jax.experimental.array_api._statistical_functions import (
198+
cumulative_sum as cumulative_sum,
198199
max as max,
199200
mean as mean,
200201
min as min,

jax/experimental/array_api/_statistical_functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
)
1919

2020

21+
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
22+
"""Calculates the cumulative sum of elements in the input array x."""
23+
return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial)
24+
2125
def max(x, /, *, axis=None, keepdims=False):
2226
"""Calculates the maximum value of the input array x."""
2327
return jax.numpy.max(x, axis=axis, keepdims=keepdims)

jax/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
cov as cov,
7979
cross as cross,
8080
csingle as csingle,
81+
cumulative_sum as cumulative_sum,
8182
delete as delete,
8283
diag as diag,
8384
diagflat as diagflat,

tests/array_api_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
'conj',
6868
'cos',
6969
'cosh',
70+
'cumulative_sum',
7071
'divide',
7172
'e',
7273
'empty',

tests/lax_numpy_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,59 @@ def np_fun(x):
286286
atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2})
287287
self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1})
288288

289+
290+
@jtu.sample_product(
291+
[dict(shape=shape, axis=axis)
292+
for shape in all_shapes
293+
for axis in list(range(-len(shape), len(shape))) + [None] if len(shape) == 1],
294+
dtype=all_dtypes,
295+
out_dtype=all_dtypes + [None],
296+
include_initial=[False, True],
297+
)
298+
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
299+
def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial):
300+
rng = jtu.rand_some_zero(self.rng())
301+
x = rng(shape, dtype)
302+
out = jnp.cumulative_sum(x, dtype=out_dtype, include_initial=include_initial)
303+
304+
target_dtype = out_dtype or x.dtype
305+
kind = x.dtype.kind
306+
if (out_dtype is None and kind in {'i', 'u'}
307+
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
308+
target_dtype = dtypes.dtype(dtypes._default_types[kind])
309+
assert out.dtype == target_dtype
310+
311+
_axis = axis or 0
312+
target_shape = list(x.shape)
313+
if include_initial:
314+
target_shape[_axis] += 1
315+
assert out.shape == tuple(target_shape)
316+
317+
target = jnp.cumsum(x, axis=_axis, dtype=out.dtype)
318+
if include_initial:
319+
zeros_shape = target_shape
320+
zeros_shape[_axis] = 1
321+
target = jnp.concat([jnp.zeros(target_shape, dtype=out.dtype), target])
322+
self.assertArraysEqual(out, target)
323+
324+
325+
@jtu.sample_product(
326+
shape=all_shapes, dtype=all_dtypes,
327+
include_initial=[False, True])
328+
def testCumulativeSumErrors(self, shape, dtype, include_initial):
329+
rng = jtu.rand_some_zero(self.rng())
330+
x = rng(shape, dtype)
331+
if jnp.isscalar(x) or x.ndim == 0:
332+
msg = r"The input must be non-scalar to take"
333+
with self.assertRaisesRegex(ValueError, msg):
334+
jnp.cumulative_sum(x, include_initial=include_initial)
335+
elif x.ndim > 1:
336+
msg = r"The input array has rank \d*, however"
337+
with self.assertRaisesRegex(ValueError, msg):
338+
jnp.cumulative_sum(x, include_initial=include_initial)
339+
340+
341+
289342
@jtu.sample_product(
290343
[dict(shape=shape, axis=axis)
291344
for shape in all_shapes

0 commit comments

Comments
 (0)