Skip to content

Commit ccdfabb

Browse files
committed
add unstack, moveaxis, swapaxes
1 parent 18bb612 commit ccdfabb

File tree

3 files changed

+183
-0
lines changed

3 files changed

+183
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,14 @@
6969
finfo,
7070
flip,
7171
iinfo,
72+
moveaxis,
7273
permute_dims,
7374
result_type,
7475
roll,
7576
squeeze,
7677
stack,
78+
swapaxes,
79+
unstack,
7780
)
7881
from dpctl.tensor._print import (
7982
get_print_options,
@@ -143,6 +146,9 @@
143146
"complex128",
144147
"iinfo",
145148
"finfo",
149+
"unstack",
150+
"moveaxis",
151+
"swapaxes",
146152
"can_cast",
147153
"result_type",
148154
"meshgrid",

dpctl/tensor/_manipulation_functions.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,107 @@ def finfo(dtype):
741741
return finfo_object(dtype)
742742

743743

744+
def unstack(X, axis=0):
745+
"""
746+
Args:
747+
x (usm_ndarray): input array
748+
749+
axis (int): axis along which X is unstacked.
750+
If `X` has rank (i.e, number of dimensions) `N`,
751+
a valid `axis` must reside in the half-open interval `[-N, N)`.
752+
default value is axis=0.
753+
754+
Returns:
755+
out (usm_narray): A tuple of arrays.
756+
757+
Raises:
758+
AxisError: if provided axis position is invalid.
759+
"""
760+
if not isinstance(X, dpt.usm_ndarray):
761+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
762+
763+
axis = normalize_axis_index(axis, X.ndim)
764+
Y = dpt.moveaxis(X, axis, 0)
765+
766+
return tuple(Y[i] for i in range(Y.shape[0]))
767+
768+
769+
def moveaxis(X, src, dst):
770+
"""
771+
Args:
772+
x (usm_ndarray): input array
773+
774+
src (int or a sequence of int): Original positions of the axes to move.
775+
These must be unique. If `X` has rank (i.e., number of dimensions) `N`,
776+
a valid `axis` must reside in the half-open interval `[-N, N)`.
777+
778+
dst (int or a sequence of int): Destination positions for each of the
779+
original axes. These must also be unique. If `X` has rank
780+
(i.e., number of dimensions) `N`, a valid `axis` must reside
781+
in the half-open interval `[-N, N)`.
782+
783+
Returns:
784+
out (usm_narray): Array with moved axes.
785+
The returned array must has the same data type as `X`,
786+
is created on the same device as `X` and has the same USM allocation
787+
type as `X`.
788+
789+
Raises:
790+
AxisError: if provided axis position is invalid.
791+
"""
792+
if not isinstance(X, dpt.usm_ndarray):
793+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
794+
795+
if not isinstance(src, (tuple, list)):
796+
src = (src,)
797+
798+
if not isinstance(dst, (tuple, list)):
799+
dst = (dst,)
800+
801+
src = normalize_axis_tuple(src, X.ndim, "src")
802+
dst = normalize_axis_tuple(dst, X.ndim, "dst")
803+
ind = list(range(0, X.ndim))
804+
for i in range(len(src)):
805+
ind.remove(src[i]) # using the value here which is the same as index
806+
ind.insert(dst[i], src[i])
807+
808+
return dpt.permute_dims(X, tuple(ind))
809+
810+
811+
def swapaxes(X, axis1, axis2):
812+
"""
813+
Args:
814+
x (usm_ndarray): input array
815+
816+
axis1 (int): First axis.
817+
If `X` has rank (i.e., number of dimensions) `N`,
818+
a valid `axis` must reside in the half-open interval `[-N, N)`.
819+
820+
axis2 (int): Second axis.
821+
If `X` has rank (i.e., number of dimensions) `N`,
822+
a valid `axis` must reside in the half-open interval `[-N, N)`.
823+
824+
Returns:
825+
out (usm_narray): Swapped array.
826+
The returned array must has the same data type as `X`,
827+
is created on the same device as `X` and has the same USM allocation
828+
type as `X`.
829+
830+
Raises:
831+
AxisError: if provided axis position is invalid.
832+
"""
833+
if not isinstance(X, dpt.usm_ndarray):
834+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
835+
836+
axis1 = normalize_axis_index(axis1, X.ndim, "axis1")
837+
axis2 = normalize_axis_index(axis2, X.ndim, "axis2")
838+
839+
ind = list(range(0, X.ndim))
840+
ind[axis1] = axis2
841+
ind[axis2] = axis1
842+
return dpt.permute_dims(X, tuple(ind))
843+
844+
744845
def _supported_dtype(dtypes):
745846
for dtype in dtypes:
746847
if dtype.char not in "?bBhHiIlLqQefdFD":

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,3 +1046,79 @@ def test_result_type():
10461046
X_np = [np.ones((2), dtype=np.int64), np.int32, "float16"]
10471047

10481048
assert dpt.result_type(*X) == np.result_type(*X_np)
1049+
1050+
1051+
def test_swapaxes_1d():
1052+
x = np.array([[1, 2, 3]])
1053+
exp = np.swapaxes(x, 0, 1)
1054+
1055+
y = dpt.asarray([[1, 2, 3]])
1056+
res = dpt.swapaxes(y, 0, 1)
1057+
1058+
assert_array_equal(exp, dpt.asnumpy(res))
1059+
1060+
1061+
def test_swapaxes_2d():
1062+
x = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
1063+
exp = np.swapaxes(x, 0, 2)
1064+
1065+
y = dpt.asarray([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
1066+
res = dpt.swapaxes(y, 0, 2)
1067+
1068+
assert_array_equal(exp, dpt.asnumpy(res))
1069+
1070+
1071+
def test_moveaxis_1axis():
1072+
x = np.arange(60).reshape((3, 4, 5))
1073+
exp = np.moveaxis(x, 0, -1)
1074+
1075+
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1076+
res = dpt.moveaxis(y, 0, -1)
1077+
1078+
assert_array_equal(exp, dpt.asnumpy(res))
1079+
1080+
1081+
def test_moveaxis_2axes():
1082+
x = np.arange(60).reshape((3, 4, 5))
1083+
exp = np.moveaxis(x, [0, 1], [-1, -2])
1084+
1085+
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1086+
res = dpt.moveaxis(y, [0, 1], [-1, -2])
1087+
1088+
assert_array_equal(exp, dpt.asnumpy(res))
1089+
1090+
1091+
def test_moveaxis_3axes():
1092+
x = np.arange(60).reshape((3, 4, 5))
1093+
exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3])
1094+
1095+
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1096+
res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3])
1097+
1098+
assert_array_equal(exp, dpt.asnumpy(res))
1099+
1100+
1101+
def test_unstack_axis0():
1102+
y = dpt.reshape(dpt.arange(6), (2, 3))
1103+
res = dpt.unstack(y)
1104+
1105+
assert_array_equal(dpt.asnumpy(y[0, ...]), dpt.asnumpy(res[0]))
1106+
assert_array_equal(dpt.asnumpy(y[1, ...]), dpt.asnumpy(res[1]))
1107+
1108+
1109+
def test_unstack_axis1():
1110+
y = dpt.reshape(dpt.arange(6), (2, 3))
1111+
res = dpt.unstack(y, 1)
1112+
1113+
assert_array_equal(dpt.asnumpy(y[:, 0, ...]), dpt.asnumpy(res[0]))
1114+
assert_array_equal(dpt.asnumpy(y[:, 1, ...]), dpt.asnumpy(res[1]))
1115+
assert_array_equal(dpt.asnumpy(y[:, 2, ...]), dpt.asnumpy(res[2]))
1116+
1117+
1118+
def test_unstack_axis2():
1119+
y = dpt.reshape(dpt.arange(60), (4, 5, 3))
1120+
res = dpt.unstack(y, 2)
1121+
1122+
assert_array_equal(dpt.asnumpy(y[:, :, 0, ...]), dpt.asnumpy(res[0]))
1123+
assert_array_equal(dpt.asnumpy(y[:, :, 1, ...]), dpt.asnumpy(res[1]))
1124+
assert_array_equal(dpt.asnumpy(y[:, :, 2, ...]), dpt.asnumpy(res[2]))

0 commit comments

Comments
 (0)