Skip to content

Commit eb063e2

Browse files
committed
Add moveaxis
1 parent 250ba86 commit eb063e2

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,15 @@
275275
concat,
276276
expand_dims,
277277
flip,
278+
moveaxis,
278279
permute_dims,
279280
reshape,
280281
roll,
281282
squeeze,
282283
stack,
283284
)
284285

285-
__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"]
286+
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "reshape", "roll", "squeeze", "stack"]
286287

287288
from ._searching_functions import argmax, argmin, nonzero, where
288289

array_api_strict/_manipulation_functions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._array_object import Array
44
from ._data_type_functions import result_type
5+
from ._flags import requires_api_version
56

67
from typing import TYPE_CHECKING
78

@@ -43,6 +44,20 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
4344
"""
4445
return Array._new(np.flip(x._array, axis=axis))
4546

47+
@requires_api_version('2023.12')
48+
def moveaxis(
49+
x: Array,
50+
source: Union[int, Tuple[int, ...]],
51+
destination: Union[int, Tuple[int, ...]],
52+
/,
53+
) -> Array:
54+
"""
55+
Array API compatible wrapper for :py:func:`np.moveaxis <numpy.moveaxis>`.
56+
57+
See its docstring for more information.
58+
"""
59+
return Array._new(np.moveaxis(x._array, source, destination))
60+
4661

4762
# Note: The function name is different here (see also matrix_transpose).
4863
# Unlike transpose(), the axes argument is required.

0 commit comments

Comments
 (0)