Skip to content

Commit 24a87fa

Browse files
squeeze and expand_dims funcs (#353)
* squeeze func impl and refactoring
1 parent 216eec7 commit 24a87fa

File tree

7 files changed

+297
-148
lines changed

7 files changed

+297
-148
lines changed

dpnp/backend_manipulation.pyx

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@ __all__ += [
4040
"dpnp_atleast_2d",
4141
"dpnp_atleast_3d",
4242
"dpnp_copyto",
43+
"dpnp_expand_dims",
4344
"dpnp_repeat",
44-
"dpnp_transpose"
45+
"dpnp_transpose",
46+
"dpnp_squeeze",
4547
]
4648

4749

@@ -129,6 +131,38 @@ cpdef dparray dpnp_copyto(dparray dst, dparray src, where=True):
129131
return dst
130132

131133

134+
cpdef dparray dpnp_expand_dims(dparray in_array, axis):
135+
axis_tuple = _object_to_tuple(axis)
136+
result_ndim = len(axis_tuple) + in_array.ndim
137+
138+
if len(axis_tuple) == 0:
139+
axis_ndim = 0
140+
else:
141+
axis_ndim = max(-min(0, min(axis_tuple)), max(0, max(axis_tuple))) + 1
142+
143+
axis_norm = _object_to_tuple(normalize_axis(axis_tuple, result_ndim))
144+
145+
if axis_ndim - len(axis_norm) > in_array.ndim:
146+
checker_throw_axis_error("dpnp_expand_dims", "axis", axis, axis_ndim)
147+
148+
if len(axis_norm) > len(set(axis_norm)):
149+
checker_throw_value_error("dpnp_expand_dims", "axis", axis, "no repeated axis")
150+
151+
shape_list = []
152+
axis_idx = 0
153+
for i in range(result_ndim):
154+
if i in axis_norm:
155+
shape_list.append(1)
156+
else:
157+
shape_list.append(in_array.shape[axis_idx])
158+
axis_idx = axis_idx + 1
159+
160+
shape = _object_to_tuple(shape_list)
161+
cdef dparray result = dpnp.copy(in_array).reshape(shape)
162+
163+
return result
164+
165+
132166
cpdef dparray dpnp_repeat(dparray array1, repeats, axes=None):
133167
cdef long new_size = array1.size * repeats
134168
cdef dparray result = dparray((new_size, ), dtype=array1.dtype)
@@ -181,3 +215,24 @@ cpdef dparray dpnp_transpose(dparray array1, axes=None):
181215
func(array1.get_data(), input_shape, result_shape, permute_axes, result.get_data(), array1.size)
182216

183217
return result
218+
219+
220+
cpdef dparray dpnp_squeeze(dparray in_array, axis):
221+
shape_list = []
222+
if axis is None:
223+
for i in range(in_array.ndim):
224+
if in_array.shape[i] != 1:
225+
shape_list.append(in_array.shape[i])
226+
else:
227+
axis_norm = _object_to_tuple(normalize_axis(_object_to_tuple(axis), in_array.ndim))
228+
for i in range(in_array.ndim):
229+
if i in axis_norm:
230+
if in_array.shape[i] != 1:
231+
checker_throw_value_error("dpnp_squeeze", "axis", axis, "axis has size not equal to one")
232+
else:
233+
shape_list.append(in_array.shape[i])
234+
235+
shape = _object_to_tuple(shape_list)
236+
cdef dparray result = dpnp.copy(in_array).reshape(shape)
237+
238+
return result

dpnp/dparray.pyx

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,17 @@ cdef class dparray:
573573
"""
574574
return std(self, axis, dtype, out, ddof, keepdims)
575575
576+
def squeeze(self, axis=None):
577+
"""
578+
Remove single-dimensional entries from the shape of an array.
579+
580+
.. seealso::
581+
:obj:`dpnp.squeeze` for full documentation
582+
583+
"""
584+
585+
return squeeze(self, axis)
586+
576587
def transpose(self, *axes):
577588
""" Returns a view of the array with axes permuted.
578589

dpnp/dpnp_iface.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
]
6262

6363
from dpnp.dpnp_iface_arraycreation import *
64-
from dpnp.dpnp_iface_arraymanipulation import *
6564
from dpnp.dpnp_iface_bitwise import *
6665
from dpnp.dpnp_iface_counting import *
6766
from dpnp.dpnp_iface_indexing import *
@@ -76,7 +75,6 @@
7675
from dpnp.dpnp_iface_trigonometric import *
7776

7877
from dpnp.dpnp_iface_arraycreation import __all__ as __all__arraycreation
79-
from dpnp.dpnp_iface_arraymanipulation import __all__ as __all__arraymanipulation
8078
from dpnp.dpnp_iface_bitwise import __all__ as __all__bitwise
8179
from dpnp.dpnp_iface_counting import __all__ as __all__counting
8280
from dpnp.dpnp_iface_indexing import __all__ as __all__indexing
@@ -91,7 +89,6 @@
9189
from dpnp.dpnp_iface_trigonometric import __all__ as __all__trigonometric
9290

9391
__all__ += __all__arraycreation
94-
__all__ += __all__arraymanipulation
9592
__all__ += __all__bitwise
9693
__all__ += __all__counting
9794
__all__ += __all__indexing

dpnp/dpnp_iface_arraymanipulation.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

dpnp/dpnp_iface_manipulation.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,52 @@
4747
from dpnp.dparray import dparray
4848
from dpnp.dpnp_utils import *
4949
import dpnp
50+
from dpnp.dpnp_iface_arraycreation import array
5051

5152

5253
__all__ = [
54+
"asfarray",
5355
"atleast_1d",
5456
"atleast_2d",
5557
"atleast_3d",
5658
"copyto",
59+
"expand_dims",
5760
"moveaxis",
5861
"ravel",
5962
"repeat",
6063
"rollaxis",
64+
"squeeze",
6165
"swapaxes",
6266
"transpose"
6367
]
6468

6569

70+
def asfarray(a, dtype=numpy.float64):
71+
"""
72+
Return an array converted to a float type.
73+
74+
For full documentation refer to :obj:`numpy.asfarray`.
75+
76+
Notes
77+
-----
78+
This function works exactly the same as :obj:`dpnp.array`.
79+
80+
"""
81+
82+
if not use_origin_backend(a):
83+
# behavior of original function: int types replaced with float64
84+
if numpy.issubdtype(dtype, numpy.integer):
85+
dtype = numpy.float64
86+
87+
# if type is the same then same object should be returned
88+
if isinstance(a, dpnp.ndarray) and a.dtype == dtype:
89+
return a
90+
91+
return array(a, dtype=dtype)
92+
93+
return call_origin(numpy.asfarray, a, dtype)
94+
95+
6696
def atleast_1d(*arys):
6797
"""
6898
Convert inputs to arrays with at least one dimension.
@@ -193,6 +223,73 @@ def copyto(dst, src, casting='same_kind', where=True):
193223
return result
194224

195225

226+
def expand_dims(a, axis):
227+
"""
228+
Expand the shape of an array.
229+
230+
Insert a new axis that will appear at the `axis` position in the expanded
231+
array shape.
232+
233+
For full documentation refer to :obj:`numpy.expand_dims`.
234+
235+
See Also
236+
--------
237+
:obj:`dpnp.squeeze` : The inverse operation, removing singleton dimensions
238+
:obj:`dpnp.reshape` : Insert, remove, and combine dimensions, and resize existing ones
239+
:obj:`dpnp.indexing`, :obj:`dpnp.atleast_1d`, :obj:`dpnp.atleast_2d`, :obj:`dpnp.atleast_3d`
240+
241+
Examples
242+
--------
243+
>>> import dpnp as np
244+
>>> x = np.array([1, 2])
245+
>>> x.shape
246+
(2,)
247+
248+
The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``:
249+
250+
>>> y = np.expand_dims(x, axis=0)
251+
>>> y
252+
array([[1, 2]])
253+
>>> y.shape
254+
(1, 2)
255+
256+
The following is equivalent to ``x[:, np.newaxis]``:
257+
258+
>>> y = np.expand_dims(x, axis=1)
259+
>>> y
260+
array([[1],
261+
[2]])
262+
>>> y.shape
263+
(2, 1)
264+
265+
``axis`` may also be a tuple:
266+
267+
>>> y = np.expand_dims(x, axis=(0, 1))
268+
>>> y
269+
array([[[1, 2]]])
270+
271+
>>> y = np.expand_dims(x, axis=(2, 0))
272+
>>> y
273+
array([[[1],
274+
[2]]])
275+
276+
Note that some examples may use ``None`` instead of ``np.newaxis``. These
277+
are the same objects:
278+
279+
>>> np.newaxis is None
280+
True
281+
282+
"""
283+
284+
if not use_origin_backend(a):
285+
if not isinstance(a, dpnp.ndarray):
286+
pass
287+
else:
288+
return dpnp_expand_dims(a, axis)
289+
290+
return call_origin(numpy.expand_dims, a, axis)
291+
292+
196293
def moveaxis(x1, source, destination):
197294
"""
198295
Move axes of an array to new positions. Other axes remain in their original order.
@@ -397,6 +494,49 @@ def rollaxis(a, axis, start=0):
397494
return nd2dp_array(result)
398495

399496

497+
def squeeze(a, axis=None):
498+
"""
499+
Remove single-dimensional entries from the shape of an array.
500+
501+
For full documentation refer to :obj:`numpy.squeeze`.
502+
503+
Examples
504+
--------
505+
>>> import dpnp as np
506+
>>> x = np.array([[[0], [1], [2]]])
507+
>>> x.shape
508+
(1, 3, 1)
509+
>>> np.squeeze(x).shape
510+
(3,)
511+
>>> np.squeeze(x, axis=0).shape
512+
(3, 1)
513+
>>> np.squeeze(x, axis=1).shape
514+
Traceback (most recent call last):
515+
...
516+
ValueError: cannot select an axis to squeeze out which has size not equal to one
517+
>>> np.squeeze(x, axis=2).shape
518+
(1, 3)
519+
>>> x = np.array([[1234]])
520+
>>> x.shape
521+
(1, 1)
522+
>>> np.squeeze(x)
523+
array(1234) # 0d array
524+
>>> np.squeeze(x).shape
525+
()
526+
>>> np.squeeze(x)[()]
527+
1234
528+
529+
"""
530+
531+
if not use_origin_backend(a):
532+
if not isinstance(a, dpnp.ndarray):
533+
pass
534+
else:
535+
return dpnp_squeeze(a, axis)
536+
537+
return call_origin(numpy.squeeze, a, axis)
538+
539+
400540
def swapaxes(x1, axis1, axis2):
401541
"""
402542
Interchange two axes of an array.

0 commit comments

Comments
 (0)