From c99adf8d2fb10c1fc6f241d0ac83f2e3bc0c8f5f Mon Sep 17 00:00:00 2001 From: Stephannie Jimenez Date: Tue, 21 Feb 2023 17:14:41 -0500 Subject: [PATCH 1/4] Add unstack function to the spec --- .../manipulation_functions.rst | 1 + .../_draft/manipulation_functions.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/spec/draft/API_specification/manipulation_functions.rst b/spec/draft/API_specification/manipulation_functions.rst index 4f43f0835..fc0e752b9 100644 --- a/spec/draft/API_specification/manipulation_functions.rst +++ b/spec/draft/API_specification/manipulation_functions.rst @@ -28,3 +28,4 @@ Objects in API roll squeeze stack + unstack diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 2d7179a8b..76002d3b5 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -199,6 +199,24 @@ def stack(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: int = 0) -> """ +def unstack(array: array, /, *, axis: int = 0) -> List[array]: + """ + Splits an array in a sequence of arrays along the given axis. + + Parameters + ---------- + array: array + input array. + axis: int + axis along which the array will be split. A valid ``axis`` must be on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of ``x``. If provided an ``axis`` outside of the required interval, the function must raise an exception. Default: ``0``. + + Returns + ------- + out: List[array] + list of slices along the given dimension. All the arrays have the same shape. + """ + + __all__ = [ "broadcast_arrays", "broadcast_to", @@ -210,4 +228,5 @@ def stack(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: int = 0) -> "roll", "squeeze", "stack", + "unstack", ] From 0b063beb5c61f561085fdb1fe0a472922e7bb131 Mon Sep 17 00:00:00 2001 From: Stephannie Jimenez Date: Thu, 23 Feb 2023 13:58:59 -0500 Subject: [PATCH 2/4] Fix return data type --- src/array_api_stubs/_draft/manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 76002d3b5..547bac7d2 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -199,7 +199,7 @@ def stack(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: int = 0) -> """ -def unstack(array: array, /, *, axis: int = 0) -> List[array]: +def unstack(array: array, /, *, axis: int = 0) -> Tuple[array]: """ Splits an array in a sequence of arrays along the given axis. From 562b83006c5dbaebd397f69e9476914e296be088 Mon Sep 17 00:00:00 2001 From: Stephannie Jimenez Date: Fri, 10 Mar 2023 10:28:31 -0500 Subject: [PATCH 3/4] Fix return type and rename parameter name to x --- src/array_api_stubs/_draft/manipulation_functions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 547bac7d2..02193cfe0 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -199,21 +199,21 @@ def stack(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: int = 0) -> """ -def unstack(array: array, /, *, axis: int = 0) -> Tuple[array]: +def unstack(x: Tuple[array, ...], /, *, axis: int = 0) -> Tuple[array]: """ Splits an array in a sequence of arrays along the given axis. Parameters ---------- - array: array + x: array input array. axis: int axis along which the array will be split. A valid ``axis`` must be on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of ``x``. If provided an ``axis`` outside of the required interval, the function must raise an exception. Default: ``0``. Returns ------- - out: List[array] - list of slices along the given dimension. All the arrays have the same shape. + out: Tuple[array] + tuple of slices along the given dimension. All the arrays have the same shape. """ From b34da5a4a6b3546966c15bc9097bc6a5058919a5 Mon Sep 17 00:00:00 2001 From: Stephannie Jimenez Date: Mon, 13 Mar 2023 14:21:02 -0500 Subject: [PATCH 4/4] Fix return type --- src/array_api_stubs/_draft/manipulation_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 02193cfe0..28ffbcff2 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -199,7 +199,7 @@ def stack(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: int = 0) -> """ -def unstack(x: Tuple[array, ...], /, *, axis: int = 0) -> Tuple[array]: +def unstack(x: array, /, *, axis: int = 0) -> Tuple[array, ...]: """ Splits an array in a sequence of arrays along the given axis. @@ -212,7 +212,7 @@ def unstack(x: Tuple[array, ...], /, *, axis: int = 0) -> Tuple[array]: Returns ------- - out: Tuple[array] + out: Tuple[array, ...] tuple of slices along the given dimension. All the arrays have the same shape. """