Skip to content

Commit ca6eb1b

Browse files
authored
use sycl_adapter in krnl_indexing (#902)
1 parent 2cfe4a3 commit ca6eb1b

File tree

3 files changed

+68
-45
lines changed

3 files changed

+68
-45
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,14 @@ INP_DLLEXPORT void dpnp_nanvar_c(void* array, void* mask_arr, void* result, size
222222
*
223223
* @param [in] array1 Input array.
224224
* @param [out] result1 Output array.
225+
* @param [in] result_size Output array size.
225226
* @param [in] shape Shape of input array.
226227
* @param [in] ndim Number of elements in shape.
227228
* @param [in] j Number input array.
228229
*/
229230
template <typename _DataType>
230231
INP_DLLEXPORT void
231-
dpnp_nonzero_c(const void* array1, void* result1, const size_t* shape, const size_t ndim, const size_t j);
232+
dpnp_nonzero_c(const void* array1, void* result1, const size_t result_size, const size_t* shape, const size_t ndim, const size_t j);
232233

233234
/**
234235
* @ingroup BACKEND_API
@@ -582,6 +583,7 @@ INP_DLLEXPORT void dpnp_diag_c(
582583
* @brief math library implementation of diagonal function
583584
*
584585
* @param [in] array Input array with data.
586+
* @param [in] input1_size Number of elements in the input array.
585587
* @param [out] result Output array.
586588
* @param [in] offset Offset of the diagonal from the main diagonal.
587589
* @param [in] shape Shape of input array.
@@ -590,7 +592,7 @@ INP_DLLEXPORT void dpnp_diag_c(
590592
*/
591593
template <typename _DataType>
592594
INP_DLLEXPORT void dpnp_diagonal_c(
593-
void* array1_in, void* result1, const size_t offset, size_t* shape, size_t* res_shape, const size_t res_ndim);
595+
void* array1_in, const size_t input1_size, void* result1, const size_t offset, size_t* shape, size_t* res_shape, const size_t res_ndim);
594596

595597
/**
596598
* @ingroup BACKEND_API
@@ -752,12 +754,13 @@ INP_DLLEXPORT void dpnp_std_c(
752754
* @brief math library implementation of take function
753755
*
754756
* @param [in] array Input array with data.
757+
* @param [in] array1_size Input array size.
755758
* @param [in] indices Input array with indices.
756759
* @param [out] result Output array.
757760
* @param [in] size Number of elements in the indices array.
758761
*/
759762
template <typename _DataType, typename _IndecesType>
760-
INP_DLLEXPORT void dpnp_take_c(void* array, void* indices, void* result, size_t size);
763+
INP_DLLEXPORT void dpnp_take_c(void* array, const size_t array1_size, void* indices, void* result, size_t size);
761764

762765
/**
763766
* @ingroup BACKEND_API

dpnp/backend/kernels/dpnp_krnl_indexing.cpp

Lines changed: 56 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,27 @@
2929

3030
#include <dpnp_iface.hpp>
3131
#include "dpnp_fptr.hpp"
32+
#include "dpnpc_memory_adapter.hpp"
3233
#include "queue_sycl.hpp"
3334

3435
template <typename _DataType>
3536
class dpnp_diagonal_c_kernel;
3637

3738
template <typename _DataType>
3839
void dpnp_diagonal_c(
39-
void* array1_in, void* result1, const size_t offset, size_t* shape, size_t* res_shape, const size_t res_ndim)
40+
void* array1_in, const size_t input1_size, void* result1, const size_t offset, size_t* shape, size_t* res_shape, const size_t res_ndim)
4041
{
41-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
42-
_DataType* result = reinterpret_cast<_DataType*>(result1);
43-
44-
size_t res_size = 1;
45-
for (size_t i = 0; i < res_ndim; ++i)
42+
const size_t res_size = std::accumulate(res_shape, res_shape + res_ndim, 1, std::multiplies<size_t>());
43+
if (!(res_size && input1_size))
4644
{
47-
res_size *= res_shape[i];
45+
return;
4846
}
4947

48+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, input1_size, true);
49+
DPNPC_ptr_adapter<_DataType> result_ptr(result1, res_size, true, true);
50+
_DataType* array_1 = input1_ptr.get_ptr();
51+
_DataType* result = result_ptr.get_ptr();
52+
5053
if (res_ndim <= 1)
5154
{
5255
for (size_t i = 0; i < res_shape[res_ndim - 1]; ++i)
@@ -146,7 +149,14 @@ void dpnp_diagonal_c(
146149
template <typename _DataType>
147150
void dpnp_fill_diagonal_c(void* array1_in, void* val_in, size_t* shape, const size_t ndim)
148151
{
149-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
152+
const size_t result_size = std::accumulate(shape, shape + ndim, 1, std::multiplies<size_t>());
153+
if (!(result_size && array1_in))
154+
{
155+
return;
156+
}
157+
158+
DPNPC_ptr_adapter<_DataType> result_ptr(array1_in, result_size, true, true);
159+
_DataType* array_1 = result_ptr.get_ptr();
150160
_DataType* val_arr = reinterpret_cast<_DataType*>(val_in);
151161

152162
size_t min_shape = shape[0];
@@ -172,11 +182,12 @@ void dpnp_fill_diagonal_c(void* array1_in, void* val_in, size_t* shape, const si
172182
}
173183
array_1[ind] = val;
174184
}
185+
175186
return;
176187
}
177188

178189
template <typename _DataType>
179-
void dpnp_nonzero_c(const void* in_array1, void* result1, const size_t* shape, const size_t ndim, const size_t j)
190+
void dpnp_nonzero_c(const void* in_array1, void* result1, const size_t result_size, const size_t* shape, const size_t ndim, const size_t j)
180191
{
181192
if ((in_array1 == nullptr) || (result1 == nullptr))
182193
{
@@ -188,22 +199,21 @@ void dpnp_nonzero_c(const void* in_array1, void* result1, const size_t* shape, c
188199
return;
189200
}
190201

191-
const _DataType* arr = reinterpret_cast<const _DataType*>(in_array1);
192-
long* result = reinterpret_cast<long*>(result1);
202+
const size_t input1_size = std::accumulate(shape, shape + ndim, 1, std::multiplies<size_t>());
203+
204+
DPNPC_ptr_adapter<_DataType> input1_ptr(in_array1, input1_size, true);
205+
DPNPC_ptr_adapter<long> result_ptr(result1, result_size, true, true);
206+
const _DataType* arr = input1_ptr.get_ptr();
207+
long* result = result_ptr.get_ptr();
193208

194-
size_t size = 1;
195-
for (size_t i = 0; i < ndim; ++i)
196-
{
197-
size *= shape[i];
198-
}
199209

200210
size_t idx = 0;
201-
for (size_t i = 0; i < size; ++i)
211+
for (size_t i = 0; i < input1_size; ++i)
202212
{
203213
if (arr[i] != 0)
204214
{
205215
size_t ids[ndim];
206-
size_t ind1 = size;
216+
size_t ind1 = input1_size;
207217
size_t ind2 = i;
208218
for (size_t k = 0; k < ndim; ++k)
209219
{
@@ -216,6 +226,7 @@ void dpnp_nonzero_c(const void* in_array1, void* result1, const size_t* shape, c
216226
idx += 1;
217227
}
218228
}
229+
219230
return;
220231
}
221232

@@ -226,13 +237,16 @@ void dpnp_place_c(void* arr_in, long* mask_in, void* vals_in, const size_t arr_s
226237
{
227238
return;
228239
}
229-
_DataType* arr = reinterpret_cast<_DataType*>(arr_in);
230240

231241
if (!vals_size)
232242
{
233243
return;
234244
}
235-
_DataType* vals = reinterpret_cast<_DataType*>(vals_in);
245+
246+
DPNPC_ptr_adapter<_DataType> input1_ptr(vals_in, vals_size, true);
247+
DPNPC_ptr_adapter<_DataType> result_ptr(arr_in, arr_size, true, true);
248+
_DataType* vals = input1_ptr.get_ptr();
249+
_DataType* arr = result_ptr.get_ptr();
236250

237251
size_t counter = 0;
238252
for (size_t i = 0; i < arr_size; ++i)
@@ -243,18 +257,15 @@ void dpnp_place_c(void* arr_in, long* mask_in, void* vals_in, const size_t arr_s
243257
counter += 1;
244258
}
245259
}
260+
246261
return;
247262
}
248263

249264
template <typename _DataType, typename _IndecesType, typename _ValueType>
250265
void dpnp_put_c(
251266
void* array1_in, void* ind_in, void* v_in, const size_t size, const size_t size_ind, const size_t size_v)
252267
{
253-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
254-
size_t* ind = reinterpret_cast<size_t*>(ind_in);
255-
_DataType* v = reinterpret_cast<_DataType*>(v_in);
256-
257-
if ((array_1 == nullptr) || (ind == nullptr) || (v == nullptr))
268+
if ((array1_in == nullptr) || (ind_in == nullptr) || (v_in == nullptr))
258269
{
259270
return;
260271
}
@@ -264,6 +275,13 @@ void dpnp_put_c(
264275
return;
265276
}
266277

278+
DPNPC_ptr_adapter<size_t> input1_ptr(ind_in, size_ind, true);
279+
DPNPC_ptr_adapter<_DataType> input2_ptr(v_in, size_v, true);
280+
DPNPC_ptr_adapter<_DataType> result_ptr(array1_in, size, true, true);
281+
size_t* ind = input1_ptr.get_ptr();
282+
_DataType* v = input2_ptr.get_ptr();
283+
_DataType* array_1 = result_ptr.get_ptr();
284+
267285
for (size_t i = 0; i < size; ++i)
268286
{
269287
for (size_t j = 0; j < size_ind; ++j)
@@ -274,6 +292,7 @@ void dpnp_put_c(
274292
}
275293
}
276294
}
295+
277296
return;
278297
}
279298

@@ -287,18 +306,16 @@ void dpnp_put_along_axis_c(void* arr_in,
287306
size_t size_indices,
288307
size_t values_size)
289308
{
290-
_DataType* arr = reinterpret_cast<_DataType*>(arr_in);
291-
size_t* indices = reinterpret_cast<size_t*>(indices_in);
292-
_DataType* values = reinterpret_cast<_DataType*>(values_in);
293-
294309
size_t res_ndim = ndim - 1;
295310
size_t res_shape[res_ndim];
311+
const size_t size_arr = std::accumulate(shape, shape + ndim, 1, std::multiplies<size_t>());
296312

297-
size_t size_arr = 1;
298-
for (size_t i = 0; i < ndim; ++i)
299-
{
300-
size_arr *= shape[i];
301-
}
313+
DPNPC_ptr_adapter<size_t> input1_ptr(indices_in, size_indices, true);
314+
DPNPC_ptr_adapter<_DataType> input2_ptr(values_in, values_size, true);
315+
DPNPC_ptr_adapter<_DataType> result_ptr(arr_in, size_arr, true, true);
316+
size_t* indices = input1_ptr.get_ptr();
317+
_DataType* values = input2_ptr.get_ptr();
318+
_DataType* arr = result_ptr.get_ptr();
302319

303320
if (axis != res_ndim)
304321
{
@@ -459,11 +476,13 @@ template <typename _DataType, typename _IndecesType>
459476
class dpnp_take_c_kernel;
460477

461478
template <typename _DataType, typename _IndecesType>
462-
void dpnp_take_c(void* array1_in, void* indices1, void* result1, size_t size)
479+
void dpnp_take_c(void* array1_in, const size_t array1_size, void* indices1, void* result1, size_t size)
463480
{
464-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
481+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, array1_size, true);
482+
DPNPC_ptr_adapter<_IndecesType> input2_ptr(indices1, size);
483+
_DataType* array_1 = input1_ptr.get_ptr();
484+
_IndecesType* indices = input2_ptr.get_ptr();
465485
_DataType* result = reinterpret_cast<_DataType*>(result1);
466-
_IndecesType* indices = reinterpret_cast<_IndecesType*>(indices1);
467486

468487
cl::sycl::range<1> gws(size);
469488
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {

dpnp/dpnp_algo/dpnp_algo_indexing.pyx

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@ __all__ += [
5555
]
5656

5757

58-
ctypedef void(*custom_indexing_2in_1out_func_ptr_t)(void *, void * , void * , size_t)
59-
ctypedef void(*custom_indexing_2in_1out_func_ptr_t_)(void * , void * , const size_t, size_t * , size_t * , const size_t)
58+
ctypedef void(*custom_indexing_2in_1out_func_ptr_t)(void *, const size_t, void * , void * , size_t)
59+
ctypedef void(*custom_indexing_2in_1out_func_ptr_t_)(void * , const size_t, void * , const size_t, size_t * , size_t * , const size_t)
6060
ctypedef void(*custom_indexing_2in_func_ptr_t)(void *, void * , size_t * , const size_t)
6161
ctypedef void(*custom_indexing_3in_func_ptr_t)(void * , void * , void * , const size_t, const size_t)
6262
ctypedef void(*custom_indexing_3in_with_axis_func_ptr_t)(void * , void * , void * , const size_t, size_t * , const size_t,
6363
const size_t, const size_t,)
6464
ctypedef void(*custom_indexing_6in_func_ptr_t)(void *, void * , void * , const size_t, const size_t, const size_t)
65-
ctypedef void(*fptr_dpnp_nonzero_t)(const void * , void * , const size_t * , const size_t , const size_t)
65+
ctypedef void(*fptr_dpnp_nonzero_t)(const void * , void * , const size_t, const size_t * , const size_t , const size_t)
6666

6767

6868
cpdef utils.dpnp_descriptor dpnp_choose(object input, list choices):
@@ -114,6 +114,7 @@ cpdef utils.dpnp_descriptor dpnp_diagonal(dpnp_descriptor input, offset=0):
114114
cdef custom_indexing_2in_1out_func_ptr_t_ func = <custom_indexing_2in_1out_func_ptr_t_ > kernel_data.ptr
115115

116116
func(input.get_data(),
117+
input.size,
117118
result.get_data(),
118119
offset,
119120
< size_t * > input_shape.data(),
@@ -192,7 +193,7 @@ cpdef tuple dpnp_nonzero(utils.dpnp_descriptor in_array1):
192193
result_shape = utils._object_to_tuple(res_size)
193194
res_arr = utils_py.create_output_descriptor_py(result_shape, dpnp.int64, None)
194195

195-
func(in_array1.get_data(), res_arr.get_data(), < size_t * > shape_arr.data(), in_array1.ndim, j)
196+
func(in_array1.get_data(), res_arr.get_data(), res_arr.size, < size_t * > shape_arr.data(), in_array1.ndim, j)
196197

197198
res_list.append(res_arr.get_pyobj())
198199

@@ -295,7 +296,7 @@ cpdef utils.dpnp_descriptor dpnp_take(utils.dpnp_descriptor input, utils.dpnp_de
295296

296297
cdef custom_indexing_2in_1out_func_ptr_t func = <custom_indexing_2in_1out_func_ptr_t > kernel_data.ptr
297298

298-
func(input.get_data(), indices.get_data(), result.get_data(), indices.size)
299+
func(input.get_data(), input.size, indices.get_data(), result.get_data(), indices.size)
299300

300301
return result
301302

0 commit comments

Comments
 (0)