Skip to content

Commit 749d5e9

Browse files
Use normalize_axis_tuple to normalize axis for all but expand_dims functions
Closes gh-1178 Adds test to cover gh-1178.
1 parent 37497b2 commit 749d5e9

File tree

2 files changed

+59
-174
lines changed

2 files changed

+59
-174
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,12 @@ def permute_dims(X, axes):
137137
"""
138138
if not isinstance(X, dpt.usm_ndarray):
139139
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
140-
if not isinstance(axes, (tuple, list)):
141-
axes = (axes,)
140+
axes = normalize_axis_tuple(axes, X.ndim, "axes")
142141
if not X.ndim == len(axes):
143142
raise ValueError(
144143
"The length of the passed axes does not match "
145144
"to the number of usm_ndarray dimensions."
146145
)
147-
axes = normalize_axis_tuple(axes, X.ndim, "axes")
148146
newstrides = tuple(X.strides[i] for i in axes)
149147
newshape = tuple(X.shape[i] for i in axes)
150148
return dpt.usm_ndarray(
@@ -187,7 +185,8 @@ def expand_dims(X, axis):
187185
"""
188186
if not isinstance(X, dpt.usm_ndarray):
189187
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
190-
if not isinstance(axis, (tuple, list)):
188+
189+
if type(axis) not in (tuple, list):
191190
axis = (axis,)
192191

193192
out_ndim = len(axis) + X.ndim
@@ -224,8 +223,6 @@ def squeeze(X, axis=None):
224223
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
225224
X_shape = X.shape
226225
if axis is not None:
227-
if not isinstance(axis, (tuple, list)):
228-
axis = (axis,)
229226
axis = normalize_axis_tuple(axis, X.ndim if X.ndim != 0 else X.ndim + 1)
230227
new_shape = []
231228
for i, x in enumerate(X_shape):
@@ -819,12 +816,6 @@ def moveaxis(X, source, destination):
819816
if not isinstance(X, dpt.usm_ndarray):
820817
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
821818

822-
if not isinstance(source, (tuple, list)):
823-
source = (source,)
824-
825-
if not isinstance(destination, (tuple, list)):
826-
destination = (destination,)
827-
828819
source = normalize_axis_tuple(source, X.ndim, "source")
829820
destination = normalize_axis_tuple(destination, X.ndim, "destination")
830821

0 commit comments

Comments
 (0)