Skip to content

Commit 18e680b

Browse files
Fixed docstrings for unstack, swapaxes, moveaxis
1 parent ccdfabb commit 18e680b

File tree

1 file changed

+43
-31
lines changed

1 file changed

+43
-31
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -742,20 +742,23 @@ def finfo(dtype):
742742

743743

744744
def unstack(X, axis=0):
745-
"""
745+
"""unstack(x, axis=0)
746+
747+
Splits an array in a sequence of arrays along the given axis.
748+
746749
Args:
747750
x (usm_ndarray): input array
748751
749-
axis (int): axis along which X is unstacked.
750-
If `X` has rank (i.e, number of dimensions) `N`,
751-
a valid `axis` must reside in the half-open interval `[-N, N)`.
752-
default value is axis=0.
752+
axis (int, optional): axis along which `x` is unstacked.
753+
If `x` has rank (i.e, number of dimensions) `N`,
754+
a valid `axis` must reside in the half-open interval `[-N, N)`.
755+
Default: `0`.
753756
754757
Returns:
755-
out (usm_narray): A tuple of arrays.
758+
Tuple[usm_ndarray,...]: A tuple of arrays.
756759
757760
Raises:
758-
AxisError: if provided axis position is invalid.
761+
AxisError: if the `axis` value is invalid.
759762
"""
760763
if not isinstance(X, dpt.usm_ndarray):
761764
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
@@ -767,27 +770,33 @@ def unstack(X, axis=0):
767770

768771

769772
def moveaxis(X, src, dst):
770-
"""
773+
"""moveaxis(x, src, dst)
774+
775+
Moves axes of an array to new positions.
776+
771777
Args:
772778
x (usm_ndarray): input array
773779
774-
src (int or a sequence of int): Original positions of the axes to move.
775-
These must be unique. If `X` has rank (i.e., number of dimensions) `N`,
776-
a valid `axis` must reside in the half-open interval `[-N, N)`.
780+
src (int or a sequence of int):
781+
Original positions of the axes to move.
782+
These must be unique. If `x` has rank (i.e., number of
783+
dimensions) `N`, a valid `axis` must be in the
784+
half-open interval `[-N, N)`.
777785
778-
dst (int or a sequence of int): Destination positions for each of the
779-
original axes. These must also be unique. If `X` has rank
780-
(i.e., number of dimensions) `N`, a valid `axis` must reside
781-
in the half-open interval `[-N, N)`.
786+
dst (int or a sequence of int):
787+
Destination positions for each of the original axes.
788+
These must also be unique. If `x` has rank
789+
(i.e., number of dimensions) `N`, a valid `axis` must be
790+
in the half-open interval `[-N, N)`.
782791
783792
Returns:
784-
out (usm_narray): Array with moved axes.
785-
The returned array must has the same data type as `X`,
786-
is created on the same device as `X` and has the same USM allocation
787-
type as `X`.
793+
usm_narray: Array with moved axes.
794+
The returned array must has the same data type as `x`,
795+
is created on the same device as `x` and has the same
796+
USM allocation type as `x`.
788797
789798
Raises:
790-
AxisError: if provided axis position is invalid.
799+
AxisError: if `axis` value is invalid.
791800
"""
792801
if not isinstance(X, dpt.usm_ndarray):
793802
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
@@ -809,26 +818,29 @@ def moveaxis(X, src, dst):
809818

810819

811820
def swapaxes(X, axis1, axis2):
812-
"""
821+
"""swapaxes(x, axis1, axis2)
822+
823+
Interchanges two axes of an array.
824+
813825
Args:
814826
x (usm_ndarray): input array
815827
816828
axis1 (int): First axis.
817-
If `X` has rank (i.e., number of dimensions) `N`,
818-
a valid `axis` must reside in the half-open interval `[-N, N)`.
829+
If `x` has rank (i.e., number of dimensions) `N`,
830+
a valid `axis` must be in the half-open interval `[-N, N)`.
819831
820832
axis2 (int): Second axis.
821-
If `X` has rank (i.e., number of dimensions) `N`,
822-
a valid `axis` must reside in the half-open interval `[-N, N)`.
833+
If `x` has rank (i.e., number of dimensions) `N`,
834+
a valid `axis` must be in the half-open interval `[-N, N)`.
823835
824836
Returns:
825-
out (usm_narray): Swapped array.
826-
The returned array must has the same data type as `X`,
827-
is created on the same device as `X` and has the same USM allocation
828-
type as `X`.
837+
usm_narray: Array with swapped axes.
838+
The returned array must has the same data type as `x`,
839+
is created on the same device as `x` and has the same USM
840+
allocation type as `x`.
829841
830842
Raises:
831-
AxisError: if provided axis position is invalid.
843+
AxisError: if `axis` value is invalid.
832844
"""
833845
if not isinstance(X, dpt.usm_ndarray):
834846
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

0 commit comments

Comments
 (0)