@@ -137,14 +137,12 @@ def permute_dims(X, axes):
137
137
"""
138
138
if not isinstance (X , dpt .usm_ndarray ):
139
139
raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
140
- if not isinstance (axes , (tuple , list )):
141
- axes = (axes ,)
140
+ axes = normalize_axis_tuple (axes , X .ndim , "axes" )
142
141
if not X .ndim == len (axes ):
143
142
raise ValueError (
144
143
"The length of the passed axes does not match "
145
144
"to the number of usm_ndarray dimensions."
146
145
)
147
- axes = normalize_axis_tuple (axes , X .ndim , "axes" )
148
146
newstrides = tuple (X .strides [i ] for i in axes )
149
147
newshape = tuple (X .shape [i ] for i in axes )
150
148
return dpt .usm_ndarray (
@@ -187,7 +185,8 @@ def expand_dims(X, axis):
187
185
"""
188
186
if not isinstance (X , dpt .usm_ndarray ):
189
187
raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
190
- if not isinstance (axis , (tuple , list )):
188
+
189
+ if type (axis ) not in (tuple , list ):
191
190
axis = (axis ,)
192
191
193
192
out_ndim = len (axis ) + X .ndim
@@ -224,8 +223,6 @@ def squeeze(X, axis=None):
224
223
raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
225
224
X_shape = X .shape
226
225
if axis is not None :
227
- if not isinstance (axis , (tuple , list )):
228
- axis = (axis ,)
229
226
axis = normalize_axis_tuple (axis , X .ndim if X .ndim != 0 else X .ndim + 1 )
230
227
new_shape = []
231
228
for i , x in enumerate (X_shape ):
@@ -819,12 +816,6 @@ def moveaxis(X, source, destination):
819
816
if not isinstance (X , dpt .usm_ndarray ):
820
817
raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
821
818
822
- if not isinstance (source , (tuple , list )):
823
- source = (source ,)
824
-
825
- if not isinstance (destination , (tuple , list )):
826
- destination = (destination ,)
827
-
828
819
source = normalize_axis_tuple (source , X .ndim , "source" )
829
820
destination = normalize_axis_tuple (destination , X .ndim , "destination" )
830
821
0 commit comments