diff --git a/spec/draft/API_specification/manipulation_functions.rst b/spec/draft/API_specification/manipulation_functions.rst index 4f43f0835..fc0e752b9 100644 --- a/spec/draft/API_specification/manipulation_functions.rst +++ b/spec/draft/API_specification/manipulation_functions.rst @@ -28,3 +28,4 @@ Objects in API roll squeeze stack + unstack diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 2d7179a8b..28ffbcff2 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -199,6 +199,24 @@ def stack(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: int = 0) -> """ +def unstack(x: array, /, *, axis: int = 0) -> Tuple[array, ...]: + """ + Splits an array in a sequence of arrays along the given axis. + + Parameters + ---------- + x: array + input array. + axis: int + axis along which the array will be split. A valid ``axis`` must be on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of ``x``. If provided an ``axis`` outside of the required interval, the function must raise an exception. Default: ``0``. + + Returns + ------- + out: Tuple[array, ...] + tuple of slices along the given dimension. All the arrays have the same shape. + """ + + __all__ = [ "broadcast_arrays", "broadcast_to", @@ -210,4 +228,5 @@ def stack(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: int = 0) -> "roll", "squeeze", "stack", + "unstack", ]