Skip to content

Commit 7d54ebe

Browse files
authored
use sycl_adapter in krnl_sorting (#916)
1 parent 270b45c commit 7d54ebe

File tree

1 file changed

+27
-20
lines changed

1 file changed

+27
-20
lines changed

dpnp/backend/kernels/dpnp_krnl_sorting.cpp

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include <dpnp_iface.hpp>
2929
#include "dpnp_fptr.hpp"
30+
#include "dpnpc_memory_adapter.hpp"
3031
#include "queue_sycl.hpp"
3132

3233
template <typename _DataType, typename _idx_DataType>
@@ -52,8 +53,10 @@ class dpnp_argsort_c_kernel;
5253
template <typename _DataType, typename _idx_DataType>
5354
void dpnp_argsort_c(void* array1_in, void* result1, size_t size)
5455
{
55-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
56-
_idx_DataType* result = reinterpret_cast<_idx_DataType*>(result1);
56+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size, true);
57+
DPNPC_ptr_adapter<_idx_DataType> result1_ptr(result1, size, true, true);
58+
_DataType* array_1 = input1_ptr.get_ptr();
59+
_idx_DataType* result = result1_ptr.get_ptr();
5760

5861
std::iota(result, result + size, 0);
5962

@@ -90,11 +93,7 @@ template <typename _DataType>
9093
void dpnp_partition_c(
9194
void* array1_in, void* array2_in, void* result1, const size_t kth, const size_t* shape_, const size_t ndim)
9295
{
93-
_DataType* arr = reinterpret_cast<_DataType*>(array1_in);
94-
_DataType* arr2 = reinterpret_cast<_DataType*>(array2_in);
95-
_DataType* result = reinterpret_cast<_DataType*>(result1);
96-
97-
if ((arr == nullptr) || (result == nullptr))
96+
if ((array1_in == nullptr) || (array2_in == nullptr) || (result1 == nullptr))
9897
{
9998
return;
10099
}
@@ -104,19 +103,23 @@ void dpnp_partition_c(
104103
return;
105104
}
106105

107-
size_t size = 1;
108-
for (size_t i = 0; i < ndim; ++i)
109-
{
110-
size *= shape_[i];
111-
}
112-
106+
const size_t size = std::accumulate(shape_, shape_ + ndim, 1, std::multiplies<size_t>());
113107
size_t size_ = size / shape_[ndim - 1];
114108

115109
if (size_ == 0)
116110
{
117111
return;
118112
}
119113

114+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size, true);
115+
DPNPC_ptr_adapter<_DataType> input2_ptr(array2_in, size, true);
116+
DPNPC_ptr_adapter<_DataType> result1_ptr(result1, size, true, true);
117+
_DataType* arr = input1_ptr.get_ptr();
118+
_DataType* arr2 = input2_ptr.get_ptr();
119+
_DataType* result = result1_ptr.get_ptr();
120+
121+
122+
120123
auto arr_to_result_event = DPNP_QUEUE.memcpy(result, arr, size * sizeof(_DataType));
121124
arr_to_result_event.wait();
122125

@@ -182,11 +185,7 @@ template <typename _DataType, typename _IndexingType>
182185
void dpnp_searchsorted_c(
183186
void* result1, const void* array1_in, const void* v1_in, bool side, const size_t arr_size, const size_t v_size)
184187
{
185-
const _DataType* arr = reinterpret_cast<const _DataType*>(array1_in);
186-
const _DataType* v = reinterpret_cast<const _DataType*>(v1_in);
187-
_IndexingType* result = reinterpret_cast<_IndexingType*>(result1);
188-
189-
if ((arr == nullptr) || (v == nullptr) || (result == nullptr))
188+
if ((array1_in == nullptr) || (v1_in == nullptr) || (result1 == nullptr))
190189
{
191190
return;
192191
}
@@ -201,6 +200,12 @@ void dpnp_searchsorted_c(
201200
return;
202201
}
203202

203+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, arr_size);
204+
DPNPC_ptr_adapter<_DataType> input2_ptr(v1_in, v_size);
205+
const _DataType* arr = input1_ptr.get_ptr();
206+
const _DataType* v = input2_ptr.get_ptr();
207+
_IndexingType* result = reinterpret_cast<_IndexingType*>(result1);
208+
204209
cl::sycl::range<2> gws(v_size, arr_size);
205210
auto kernel_parallel_for_func = [=](cl::sycl::id<2> global_id) {
206211
size_t i = global_id[0];
@@ -281,8 +286,10 @@ class dpnp_sort_c_kernel;
281286
template <typename _DataType>
282287
void dpnp_sort_c(void* array1_in, void* result1, size_t size)
283288
{
284-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
285-
_DataType* result = reinterpret_cast<_DataType*>(result1);
289+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size, true);
290+
DPNPC_ptr_adapter<_DataType> result1_ptr(result1, size, true, true);
291+
_DataType* array_1 = input1_ptr.get_ptr();
292+
_DataType* result = result1_ptr.get_ptr();
286293

287294
std::copy(array_1, array_1 + size, result);
288295

0 commit comments

Comments
 (0)