Skip to content

Commit c6a7f75

Browse files
authored
add kernel dpnp.flatten (#645)
* add kernel dpnp.flatten
1 parent 5424ba0 commit c6a7f75

File tree

6 files changed

+41
-5
lines changed

6 files changed

+41
-5
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ enum class DPNPFuncName : size_t
107107
DPNP_FN_FABS, /**< Used in numpy.fabs() implementation */
108108
DPNP_FN_FFT_FFT, /**< Used in numpy.fft.fft() implementation */
109109
DPNP_FN_FILL_DIAGONAL, /**< Used in numpy.fill_diagonal() implementation */
110+
DPNP_FN_FLATTEN, /**< Used in numpy.flatten() implementation */
110111
DPNP_FN_FLOOR, /**< Used in numpy.floor() implementation */
111112
DPNP_FN_FLOOR_DIVIDE, /**< Used in numpy.floor_divide() implementation */
112113
DPNP_FN_FMOD, /**< Used in numpy.fmod() implementation */

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,13 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
309309
fmap[DPNPFuncName::DPNP_FN_ERF][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_erf_c<float>};
310310
fmap[DPNPFuncName::DPNP_FN_ERF][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_erf_c<double>};
311311

312+
fmap[DPNPFuncName::DPNP_FN_FLATTEN][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_copy_c<bool>};
313+
fmap[DPNPFuncName::DPNP_FN_FLATTEN][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_copy_c<int>};
314+
fmap[DPNPFuncName::DPNP_FN_FLATTEN][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_copy_c<long>};
315+
fmap[DPNPFuncName::DPNP_FN_FLATTEN][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_copy_c<float>};
316+
fmap[DPNPFuncName::DPNP_FN_FLATTEN][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_copy_c<double>};
317+
fmap[DPNPFuncName::DPNP_FN_FLATTEN][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_copy_c<std::complex<double>>};
318+
312319
fmap[DPNPFuncName::DPNP_FN_RECIP][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_recip_c<int>};
313320
fmap[DPNPFuncName::DPNP_FN_RECIP][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_recip_c<long>};
314321
fmap[DPNPFuncName::DPNP_FN_RECIP][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_recip_c<float>};

dpnp/dparray.pyx

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -512,11 +512,7 @@ cdef class dparray:
512512
if order == 'F':
513513
return self.transpose().reshape(self.size)
514514
515-
result = dparray(self.size, dtype=self.dtype)
516-
for i in range(result.size):
517-
result[i] = self[i]
518-
519-
return result
515+
return dpnp_flatten(self)
520516
521517
result = utils.dp2nd_array(self).flatten(order=order)
522518

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
8080
DPNP_FN_FABS
8181
DPNP_FN_FFT_FFT
8282
DPNP_FN_FILL_DIAGONAL
83+
DPNP_FN_FLATTEN
8384
DPNP_FN_FLOOR
8485
DPNP_FN_FLOOR_DIVIDE
8586
DPNP_FN_FMOD
@@ -229,6 +230,7 @@ cdef dparray call_fptr_2in_1out(DPNPFuncName fptr_name, object x1_obj, object x2
229230

230231

231232
cpdef dparray dpnp_astype(dparray array1, dtype_target)
233+
cpdef dparray dpnp_flatten(dparray array1)
232234

233235

234236
"""

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ __all__ = [
4545
"dpnp_arange",
4646
"dpnp_array",
4747
"dpnp_astype",
48+
"dpnp_flatten",
4849
"dpnp_init_val",
4950
"dpnp_matmul",
5051
"dpnp_queue_initialize",
@@ -70,6 +71,7 @@ include "dpnp_algo_trigonometric.pyx"
7071

7172
ctypedef void(*fptr_dpnp_arange_t)(size_t, size_t, void *, size_t)
7273
ctypedef void(*fptr_dpnp_astype_t)(const void *, void * , const size_t)
74+
ctypedef void(*fptr_dpnp_flatten_t)(const void *, void * , const size_t)
7375
ctypedef void(*fptr_dpnp_initval_t)(void *, void * , size_t)
7476

7577

@@ -140,6 +142,20 @@ cpdef dparray dpnp_astype(dparray array1, dtype_target):
140142
return result
141143

142144

145+
cpdef dparray dpnp_flatten(dparray array_):
146+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array_.dtype)
147+
148+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_FLATTEN, param1_type, param1_type)
149+
150+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
151+
cdef dparray result = dparray(array_.size, dtype=result_type)
152+
153+
cdef fptr_dpnp_flatten_t func = <fptr_dpnp_flatten_t > kernel_data.ptr
154+
func(array_.get_data(), result.get_data(), array_.size)
155+
156+
return result
157+
158+
143159
cpdef dparray dpnp_init_val(shape, dtype, value):
144160
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(dtype)
145161

tests/test_dparray.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,17 @@ def test_astype(arr, arr_dtype, res_dtype):
1818
expected = numpy_array.astype(res_dtype)
1919
result = dpnp_array.astype(res_dtype)
2020
numpy.testing.assert_array_equal(expected, result)
21+
22+
23+
@pytest.mark.parametrize("arr_dtype",
24+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool, numpy.bool_, numpy.complex],
25+
ids=['float64', 'float32', 'int64', 'int32', 'bool', 'bool_', 'complex'])
26+
@pytest.mark.parametrize("arr",
27+
[[-2, -1, 0, 1, 2], [[-2, -1], [1, 2]], []],
28+
ids=['[-2, -1, 0, 1, 2]', '[[-2, -1], [1, 2]]', '[]'])
29+
def test_flatten(arr, arr_dtype):
30+
numpy_array = numpy.array(arr, dtype=arr_dtype)
31+
dpnp_array = dpnp.array(arr, dtype=arr_dtype)
32+
expected = numpy_array.flatten()
33+
result = dpnp_array.flatten()
34+
numpy.testing.assert_array_equal(expected, result)

0 commit comments

Comments
 (0)