Skip to content

Commit 92ebdff

Browse files
committed
Add cumulative_sum wrapper for the numpy-likes
1 parent 231ef95 commit 92ebdff

File tree

4 files changed

+37
-3
lines changed

4 files changed

+37
-3
lines changed

array_api_compat/common/_aliases.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,36 @@ def var(
264264
) -> ndarray:
265265
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
266266

267+
# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
268+
# argument
269+
270+
def cumulative_sum(
271+
x: ndarray,
272+
/,
273+
xp,
274+
*,
275+
axis: Optional[int] = None,
276+
dtype: Optional[Dtype] = None,
277+
include_initial: bool = False,
278+
**kwargs
279+
) -> ndarray:
280+
# TODO: The standard is not clear about what should happen when x.ndim == 0.
281+
if axis is None:
282+
if x.ndim > 1:
283+
raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
284+
axis = 0
285+
286+
res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
287+
288+
# np.cumsum does not support include_initial
289+
if include_initial:
290+
initial_shape = list(x.shape)
291+
initial_shape[axis] = 1
292+
res = xp.concatenate(
293+
[xp.zeros_like(res, shape=initial_shape), res],
294+
axis=axis,
295+
)
296+
return res
267297

268298
# The min and max argument names in clip are different and not optional in numpy, and type
269299
# promotion behavior is different.
@@ -502,6 +532,7 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
502532
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
503533
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
504534
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
505-
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape',
506-
'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul',
507-
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', 'unstack']
535+
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
536+
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
537+
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
538+
'unstack']

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
astype = _aliases.astype
5050
std = get_xp(cp)(_aliases.std)
5151
var = get_xp(cp)(_aliases.var)
52+
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
5253
clip = get_xp(cp)(_aliases.clip)
5354
permute_dims = get_xp(cp)(_aliases.permute_dims)
5455
reshape = get_xp(cp)(_aliases.reshape)

array_api_compat/dask/array/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def _dask_arange(
9191
permute_dims = get_xp(da)(_aliases.permute_dims)
9292
std = get_xp(da)(_aliases.std)
9393
var = get_xp(da)(_aliases.var)
94+
cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
9495
empty = get_xp(da)(_aliases.empty)
9596
empty_like = get_xp(da)(_aliases.empty_like)
9697
full = get_xp(da)(_aliases.full)

array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
astype = _aliases.astype
5050
std = get_xp(np)(_aliases.std)
5151
var = get_xp(np)(_aliases.var)
52+
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
5253
clip = get_xp(np)(_aliases.clip)
5354
permute_dims = get_xp(np)(_aliases.permute_dims)
5455
reshape = get_xp(np)(_aliases.reshape)

0 commit comments

Comments
 (0)