Skip to content

Commit f367dfe

Browse files
authored
add kernel for take (#542)
* add kernel for take
1 parent 113e6be commit f367dfe

File tree

8 files changed

+97
-6
lines changed

8 files changed

+97
-6
lines changed

dpnp/backend/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ set(DPNP_SRC
166166
kernels/dpnp_krnl_common.cpp
167167
kernels/dpnp_krnl_elemwise.cpp
168168
kernels/dpnp_krnl_fft.cpp
169+
kernels/dpnp_krnl_indexing.cpp
169170
kernels/dpnp_krnl_linalg.cpp
170171
kernels/dpnp_krnl_manipulation.cpp
171172
kernels/dpnp_krnl_mathematical.cpp

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,19 @@ template <typename _DataType, typename _ResultType>
405405
INP_DLLEXPORT void dpnp_std_c(
406406
void* array, void* result, const size_t* shape, size_t ndim, const size_t* axis, size_t naxis, size_t ddof);
407407

408+
/**
409+
* @ingroup BACKEND_API
410+
* @brief math library implementation of take function
411+
*
412+
* @param [in] array Input array with data.
413+
* @param [in] array Input array with indices.
414+
* @param [out] result Output array with indeces.
415+
* @param [in] size Number of elements in the input array.
416+
*/
417+
template <typename _DataType>
418+
INP_DLLEXPORT void dpnp_take_c(
419+
void* array, void* indices, void* result, size_t size);
420+
408421
/**
409422
* @ingroup BACKEND_API
410423
* @brief math library implementation of var function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ enum class DPNPFuncName : size_t
162162
DPNP_FN_SUBTRACT, /**< Used in numpy.subtract() implementation */
163163
DPNP_FN_SUM, /**< Used in numpy.sum() implementation */
164164
DPNP_FN_SVD, /**< Used in numpy.linalg.svd() implementation */
165+
DPNP_FN_TAKE, /**< Used in numpy.take() implementation */
165166
DPNP_FN_TAN, /**< Used in numpy.tan() implementation */
166167
DPNP_FN_TANH, /**< Used in numpy.tanh() implementation */
167168
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() implementation */
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2016-2020, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <iostream>
27+
#include <list>
28+
29+
#include <dpnp_iface.hpp>
30+
#include "dpnp_fptr.hpp"
31+
#include "dpnp_utils.hpp"
32+
#include "queue_sycl.hpp"
33+
34+
35+
template <typename _DataType>
36+
class dpnp_take_c_kernel;
37+
38+
template <typename _DataType>
39+
void dpnp_take_c(void* array1_in, void* indices1, void* result1, size_t size)
40+
{
41+
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
42+
_DataType* result = reinterpret_cast<_DataType*>(result1);
43+
size_t* indices = reinterpret_cast<size_t*>(indices1);
44+
45+
for (size_t i = 0; i < size; i++)
46+
{
47+
size_t ind = indices[i];
48+
result[i] = array_1[ind];
49+
}
50+
51+
return;
52+
}
53+
54+
55+
void func_map_init_indexing_func(func_map_t& fmap)
56+
{
57+
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_take_c<int>};
58+
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_take_c<long>};
59+
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_take_c<float>};
60+
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_take_c<double>};
61+
62+
return;
63+
}

dpnp/backend/src/dpnp_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ const DPNPFuncType eft_C128 = DPNPFuncType::DPNP_FT_CMPLX128;
6868
void func_map_init_bitwise(func_map_t& fmap);
6969
void func_map_init_elemwise(func_map_t& fmap);
7070
void func_map_init_fft_func(func_map_t& fmap);
71+
void func_map_init_indexing_func(func_map_t& fmap);
7172
void func_map_init_linalg(func_map_t& fmap);
7273
void func_map_init_linalg_func(func_map_t& fmap);
7374
void func_map_init_manipulation(func_map_t& fmap);

dpnp/backend/src/dpnp_iface_fptr.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ static func_map_t func_map_init()
137137
func_map_init_bitwise(fmap);
138138
func_map_init_elemwise(fmap);
139139
func_map_init_fft_func(fmap);
140+
func_map_init_indexing_func(fmap);
140141
func_map_init_linalg(fmap);
141142
func_map_init_linalg_func(fmap);
142143
func_map_init_manipulation(fmap);

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
135135
DPNP_FN_SUBTRACT
136136
DPNP_FN_SUM
137137
DPNP_FN_SVD
138+
DPNP_FN_TAKE
138139
DPNP_FN_TAN
139140
DPNP_FN_TANH
140141
DPNP_FN_TRANSPOSE

dpnp/dpnp_algo/dpnp_algo_indexing.pyx

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ __all__ += [
5757
]
5858

5959

60+
ctypedef void(*custom_indexing_2in_1out_func_ptr_t)(void *, void * , void * , size_t)
61+
62+
6063
cpdef dparray dpnp_choose(input, choices):
6164
res_array = dparray(len(input), dtype=choices[0].dtype)
6265
for i in range(len(input)):
@@ -259,12 +262,19 @@ cpdef dparray dpnp_select(condlist, choicelist, default):
259262

260263
cpdef dparray dpnp_take(dparray input, dparray indices):
261264
indices_size = indices.size
262-
res_array = dparray(indices_size, dtype=input.dtype)
263-
for i in range(indices_size):
264-
ind = indices[i]
265-
res_array[i] = input[ind]
266-
result = res_array.reshape(indices.shape)
267-
return result
265+
266+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
267+
268+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TAKE, param1_type, param1_type)
269+
270+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
271+
cdef dparray result = dparray(indices_size, dtype=result_type)
272+
273+
cdef custom_indexing_2in_1out_func_ptr_t func = <custom_indexing_2in_1out_func_ptr_t > kernel_data.ptr
274+
275+
func(input.get_data(), indices.get_data(), result.get_data(), indices_size)
276+
277+
return result.reshape(indices.shape)
268278

269279

270280
cpdef tuple dpnp_tril_indices(n, k=0, m=None):

0 commit comments

Comments
 (0)