Skip to content

Commit a452d3c

Browse files
authored
Native repeat (#644)
* add kernel dpnp.repeat
1 parent 0afe971 commit a452d3c

File tree

8 files changed

+88
-28
lines changed

8 files changed

+88
-28
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,18 @@ INP_DLLEXPORT void dpnp_ones_like_c(void* result, size_t size);
821821
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
822822
INP_DLLEXPORT void dpnp_remainder_c(void* array1_in, void* array2_in, void* result1, size_t size);
823823

824+
/**
825+
* @ingroup BACKEND_API
826+
* @brief repeat elements of an array.
827+
*
828+
* @param [in] array_in Input array.
829+
* @param [out] result Output array.
830+
* @param [in] repeats The number of repetitions for each element.
831+
* @param [in] size Number of elements in input arrays.
832+
*/
833+
template <typename _DataType>
834+
INP_DLLEXPORT void dpnp_repeat_c(const void* array_in, void* result, const size_t repeats, const size_t size);
835+
824836
/**
825837
* @ingroup BACKEND_API
826838
* @brief copyto function.

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ enum class DPNPFuncName : size_t
144144
DPNP_FN_RADIANS, /**< Used in numpy.radians() implementation */
145145
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() implementation */
146146
DPNP_FN_RECIP, /**< Used in numpy.recip() implementation */
147+
DPNP_FN_REPEAT, /**< Used in numpy.repeat() implementation */
147148
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() implementation */
148149
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() implementation */
149150
DPNP_FN_RNG_BINOMIAL, /**< Used in numpy.random.binomial() implementation */

dpnp/backend/kernels/dpnp_krnl_manipulation.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,43 @@ void dpnp_copyto_c(void* destination, void* source, const size_t size)
3939
__dpnp_copyto_c<_DataType_src, _DataType_dst>(source, destination, size);
4040
}
4141

42+
template <typename _DataType>
43+
class dpnp_repeat_c_kernel;
44+
45+
template <typename _DataType>
46+
void dpnp_repeat_c(const void* array1_in, void* result1, const size_t repeats, const size_t size)
47+
{
48+
cl::sycl::event event;
49+
50+
const _DataType* array_in = reinterpret_cast<const _DataType*>(array1_in);
51+
_DataType* result = reinterpret_cast<_DataType*>(result1);
52+
53+
if (!array_in || !result)
54+
{
55+
return;
56+
}
57+
58+
if (!size || !repeats)
59+
{
60+
return;
61+
}
62+
63+
cl::sycl::range<2> gws(size, repeats);
64+
auto kernel_parallel_for_func = [=](cl::sycl::id<2> global_id) {
65+
size_t idx1 = global_id[0];
66+
size_t idx2 = global_id[1];
67+
result[(idx1 * repeats) + idx2] = array_in[idx1];
68+
};
69+
70+
auto kernel_func = [&](cl::sycl::handler& cgh) {
71+
cgh.parallel_for<class dpnp_repeat_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
72+
};
73+
74+
event = DPNP_QUEUE.submit(kernel_func);
75+
76+
event.wait();
77+
}
78+
4279
template <typename _KernelNameSpecialization>
4380
class dpnp_elemwise_transpose_c_kernel;
4481

@@ -134,6 +171,11 @@ void func_map_init_manipulation(func_map_t& fmap)
134171
fmap[DPNPFuncName::DPNP_FN_COPYTO][eft_C128][eft_C128] = {
135172
eft_C128, (void*)dpnp_copyto_c<std::complex<double>, std::complex<double>>};
136173

174+
fmap[DPNPFuncName::DPNP_FN_REPEAT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_repeat_c<int>};
175+
fmap[DPNPFuncName::DPNP_FN_REPEAT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_repeat_c<long>};
176+
fmap[DPNPFuncName::DPNP_FN_REPEAT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_repeat_c<float>};
177+
fmap[DPNPFuncName::DPNP_FN_REPEAT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_repeat_c<double>};
178+
137179
fmap[DPNPFuncName::DPNP_FN_TRANSPOSE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_elemwise_transpose_c<int>};
138180
fmap[DPNPFuncName::DPNP_FN_TRANSPOSE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_elemwise_transpose_c<long>};
139181
fmap[DPNPFuncName::DPNP_FN_TRANSPOSE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_elemwise_transpose_c<float>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
116116
DPNP_FN_RADIANS
117117
DPNP_FN_REMAINDER
118118
DPNP_FN_RECIP
119+
DPNP_FN_REPEAT
119120
DPNP_FN_RIGHT_SHIFT
120121
DPNP_FN_RNG_BETA
121122
DPNP_FN_RNG_BINOMIAL

dpnp/dpnp_algo/dpnp_algo_manipulation.pyx

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ __all__ += [
5050
# C function pointer to the C library template functions
5151
ctypedef void(*fptr_custom_elemwise_transpose_1in_1out_t)(void * , size_t * , size_t * ,
5252
size_t * , size_t, void * , size_t)
53+
ctypedef void(*fptr_dpnp_repeat_t)(const void *, void * , const size_t , const size_t)
5354

5455

5556
cpdef dparray dpnp_atleast_2d(dparray arr):
@@ -129,12 +130,16 @@ cpdef dparray dpnp_expand_dims(dparray in_array, axis):
129130

130131

131132
cpdef dparray dpnp_repeat(dparray array1, repeats, axes=None):
133+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype)
134+
135+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_REPEAT, param1_type, param1_type)
136+
137+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
132138
cdef long new_size = array1.size * repeats
133139
cdef dparray result = dparray((new_size, ), dtype=array1.dtype)
134140

135-
for idx2 in range(array1.size):
136-
for idx1 in range(repeats):
137-
result[(idx2 * repeats) + idx1] = array1[idx2]
141+
cdef fptr_dpnp_repeat_t func = <fptr_dpnp_repeat_t > kernel_data.ptr
142+
func(array1.get_data(), result.get_data(), repeats, array1.size)
138143

139144
return result
140145

dpnp/dpnp_iface_manipulation.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -403,30 +403,20 @@ def repeat(x1, repeats, axis=None):
403403
404404
"""
405405

406-
is_x1_dparray = isinstance(x1, dparray)
407-
408-
if (not use_origin_backend(x1) and is_x1_dparray and (axis is None or axis == 0) and (x1.ndim < 2)):
409-
410-
repeat_val = repeats
411-
if isinstance(repeats, (tuple, list)):
412-
if (len(repeats) > 1):
413-
checker_throw_value_error("repeat", "len(repeats)", len(repeats), 1)
414-
415-
repeat_val = repeats[0]
416-
417-
return dpnp_repeat(x1, repeat_val, axis)
418-
419-
input1 = dpnp.asnumpy(x1) if is_x1_dparray else x1
420-
421-
# TODO need to put dparray memory into NumPy call
422-
result_numpy = numpy.repeat(input1, repeats, axis=axis)
423-
result = result_numpy
424-
if isinstance(result, numpy.ndarray):
425-
result = dparray(result_numpy.shape, dtype=result_numpy.dtype)
426-
for i in range(result.size):
427-
result._setitem_scalar(i, result_numpy.item(i))
406+
if not use_origin_backend(x1):
407+
if not isinstance(x1, dparray):
408+
pass
409+
elif axis is not None and axis != 0:
410+
pass
411+
elif x1.ndim >= 2:
412+
pass
413+
elif not dpnp.isscalar(repeats) and len(repeats) > 1:
414+
pass
415+
else:
416+
repeat_val = repeats if dpnp.isscalar(repeats) else repeats[0]
417+
return dpnp_repeat(x1, repeat_val, axis)
428418

429-
return result
419+
return call_origin(numpy.repeat, x1, repeats, axis)
430420

431421

432422
def rollaxis(a, axis, start=0):

tests/skipped_tests.tbl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -974,8 +974,6 @@ tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_1_{shap
974974
tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_1_{shape=()}::test_shape_list
975975
tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_2_{shape=(4,)}::test_shape
976976
tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_2_{shape=(4,)}::test_shape_list
977-
tests/third_party/cupy/manipulation_tests/test_tiling.py::TestRepeat1D_param_3_{axis=None, repeats=[1, 2, 3, 4]}::test_array_repeat
978-
tests/third_party/cupy/manipulation_tests/test_tiling.py::TestRepeat1D_param_4_{axis=0, repeats=[1, 2, 3, 4]}::test_array_repeat
979977
tests/third_party/cupy/manipulation_tests/test_tiling.py::TestRepeatRepeatsNdarray::test_func
980978
tests/third_party/cupy/manipulation_tests/test_tiling.py::TestRepeatRepeatsNdarray::test_method
981979
tests/third_party/cupy/manipulation_tests/test_tiling.py::TestTileFailure_param_0_{reps=-1}::test_tile_failure

tests/test_manipulation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,14 @@ def test_copyto_dtype(in_obj, out_dtype):
2121
dpnp.copyto(result, dparr)
2222

2323
numpy.testing.assert_array_equal(result, expected)
24+
25+
26+
@pytest.mark.parametrize("arr",
27+
[[], [1, 2, 3, 4], [[1, 2], [3, 4]], [[[1], [2]], [[3], [4]]]],
28+
ids=['[]', '[1, 2, 3, 4]', '[[1, 2], [3, 4]]', '[[[1], [2]], [[3], [4]]]'])
29+
def test_repeat(arr):
30+
a = numpy.array(arr)
31+
dpnp_a = dpnp.array(arr)
32+
expected = numpy.repeat(a, 2)
33+
result = dpnp.repeat(dpnp_a, 2)
34+
numpy.testing.assert_array_equal(expected, result)

0 commit comments

Comments
 (0)