Skip to content

Commit a30536b

Browse files
committed
Add unstack()
1 parent dc1baad commit a30536b

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,10 @@
285285
squeeze,
286286
stack,
287287
tile,
288+
unstack,
288289
)
289290

290-
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile"]
291+
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"]
291292

292293
from ._searching_functions import argmax, argmin, nonzero, searchsorted, where
293294

array_api_strict/_manipulation_functions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,15 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array:
167167
if not isinstance(repetitions, tuple):
168168
raise TypeError("repetitions must be a tuple")
169169
return Array._new(np.tile(x._array, repetitions))
170+
171+
# Note: this function is new
172+
@requires_api_version('2023.12')
173+
def unstack(x: Array, /, *, axis: int = 0) -> Tuple[Array, ...]:
174+
if not (-x.ndim <= axis < x.ndim):
175+
raise ValueError("axis out of range")
176+
177+
if axis < 0:
178+
axis += x.ndim
179+
180+
slices = (slice(None),) * axis
181+
return tuple(x[slices + (i, ...)] for i in range(x.shape[axis]))

0 commit comments

Comments
 (0)