Skip to content

Commit 92a0a66

Browse files
samir-nasiblishssf
andauthored
ENH: dpnp.concatenate fallback; tests (#811)
Co-authored-by: Sergey Shalnov <shssf@users.noreply.github.com>
1 parent 39de3dc commit 92a0a66

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
"atleast_1d",
5757
"atleast_2d",
5858
"atleast_3d",
59+
"concatenate",
5960
"copyto",
6061
"expand_dims",
6162
"hstack",
@@ -178,6 +179,43 @@ def atleast_3d(*arys):
178179
return call_origin(numpy.atleast_3d, *arys)
179180

180181

182+
def concatenate(arrs, axis=0, out=None, dtype=None, casting="same_kind"):
183+
"""
184+
Join a sequence of arrays along an existing axis.
185+
186+
For full documentation refer to :obj:`numpy.concatenate`.
187+
188+
Examples
189+
--------
190+
>>> import dpnp
191+
>>> a = dpnp.array([[1, 2], [3, 4]])
192+
>>> b = dpnp.array([[5, 6]])
193+
>>> res = dpnp.concatenate((a, b), axis=0)
194+
>>> print(res)
195+
[[1 2]
196+
[3 4]
197+
[5 6]]
198+
>>> res = dpnp.concatenate((a, b.T), axis=1)
199+
>>> print(res)
200+
[[1 2 5]
201+
[3 4 6]]
202+
>>> res = dpnp.concatenate((a, b), axis=None)
203+
>>> print(res)
204+
[1 2 3 4 5 6]
205+
206+
"""
207+
208+
# TODO:
209+
# `call_origin` cannot convert sequence of dparray to sequence of
210+
# ndarrays
211+
arrs_new = []
212+
for arr in arrs:
213+
arrx = dpnp.asnumpy(arr) if isinstance(arr, dparray) else arr
214+
arrs_new.append(arrx)
215+
216+
return call_origin(numpy.concatenate, arrs_new, axis=axis, out=out, dtype=dtype, casting=casting)
217+
218+
181219
def copyto(dst, src, casting='same_kind', where=True):
182220
"""
183221
Copies values from one array to another, broadcasting as necessary.

tests/test_arraymanipulation.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,63 @@ def test_asfarray(type, input):
1717
numpy.testing.assert_array_equal(dpnp_res, np_res)
1818

1919

20+
class TestConcatenate:
21+
def test_returns_copy(self):
22+
a = dpnp.array(numpy.eye(3))
23+
b = dpnp.concatenate([a])
24+
b[0, 0] = 2
25+
assert b[0, 0] != a[0, 0]
26+
27+
def test_large_concatenate_axis_None(self):
28+
x = dpnp.arange(1, 100)
29+
r = dpnp.concatenate(x, None)
30+
numpy.testing.assert_array_equal(x, r)
31+
r = dpnp.concatenate(x, 100)
32+
numpy.testing.assert_array_equal(x, r)
33+
34+
def test_concatenate(self):
35+
# Test concatenate function
36+
# One sequence returns unmodified (but as array)
37+
r4 = list(range(4))
38+
numpy.testing.assert_array_equal(dpnp.concatenate((r4,)), r4)
39+
# Any sequence
40+
numpy.testing.assert_array_equal(dpnp.concatenate((tuple(r4),)), r4)
41+
numpy.testing.assert_array_equal(dpnp.concatenate((dpnp.array(r4),)), r4)
42+
# 1D default concatenation
43+
r3 = list(range(3))
44+
numpy.testing.assert_array_equal(dpnp.concatenate((r4, r3)), r4 + r3)
45+
# Mixed sequence types
46+
numpy.testing.assert_array_equal(dpnp.concatenate((tuple(r4), r3)), r4 + r3)
47+
numpy.testing.assert_array_equal(dpnp.concatenate((dpnp.array(r4), r3)), r4 + r3)
48+
# Explicit axis specification
49+
numpy.testing.assert_array_equal(dpnp.concatenate((r4, r3), 0), r4 + r3)
50+
# Including negative
51+
numpy.testing.assert_array_equal(dpnp.concatenate((r4, r3), -1), r4 + r3)
52+
# 2D
53+
a23 = dpnp.array([[10, 11, 12], [13, 14, 15]])
54+
a13 = dpnp.array([[0, 1, 2]])
55+
res = dpnp.array([[10, 11, 12], [13, 14, 15], [0, 1, 2]])
56+
numpy.testing.assert_array_equal(dpnp.concatenate((a23, a13)), res)
57+
numpy.testing.assert_array_equal(dpnp.concatenate((a23, a13), 0), res)
58+
numpy.testing.assert_array_equal(dpnp.concatenate((a23.T, a13.T), 1), res.T)
59+
numpy.testing.assert_array_equal(dpnp.concatenate((a23.T, a13.T), -1), res.T)
60+
# Arrays much match shape
61+
numpy.testing.assert_raises(ValueError, dpnp.concatenate, (a23.T, a13.T), 0)
62+
# 3D
63+
res = dpnp.arange(2 * 3 * 7).reshape((2, 3, 7))
64+
a0 = res[..., :4]
65+
a1 = res[..., 4:6]
66+
a2 = res[..., 6:]
67+
numpy.testing.assert_array_equal(dpnp.concatenate((a0, a1, a2), 2), res)
68+
numpy.testing.assert_array_equal(dpnp.concatenate((a0, a1, a2), -1), res)
69+
numpy.testing.assert_array_equal(dpnp.concatenate((a0.T, a1.T, a2.T), 0), res.T)
70+
71+
out = res.copy()
72+
rout = dpnp.concatenate((a0, a1, a2), 2, out=out)
73+
numpy.testing.assert_(out is rout)
74+
numpy.testing.assert_equal(res, rout)
75+
76+
2077
class TestHstack:
2178
def test_non_iterable(self):
2279
numpy.testing.assert_raises(TypeError, dpnp.hstack, 1)

0 commit comments

Comments
 (0)