Skip to content

Commit a152e41

Browse files
authored
Fix partition (#675)
* fix partition
1 parent c181e8c commit a152e41

File tree

3 files changed

+60
-21
lines changed

3 files changed

+60
-21
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,14 +328,15 @@ INP_DLLEXPORT void dpnp_sum_c(void* result_out,
328328
* @ingroup BACKEND_API
329329
* @brief Place of array elements
330330
*
331-
* @param [in] sort_array Input sorted array.
331+
* @param [in] array Input array.
332+
* @param [in] array2 Copy input array.
332333
* @param [out] result Result array.
333334
* @param [in] kth Element index to partition by.
334335
* @param [in] shape Shape of input array.
335336
* @param [in] ndim Number of elements in shape.
336337
*/
337338
template <typename _DataType>
338-
INP_DLLEXPORT void dpnp_partition_c(const void* sort_array, void* result, const size_t kth, const size_t* shape, const size_t ndim);
339+
INP_DLLEXPORT void dpnp_partition_c(void* array, void* array2, void* result, const size_t kth, const size_t* shape, const size_t ndim);
339340

340341
/**
341342
* @ingroup BACKEND_API

dpnp/backend/kernels/dpnp_krnl_sorting.cpp

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -87,54 +87,92 @@ template <typename _DataType>
8787
class dpnp_partition_c_kernel;
8888

8989
template <typename _DataType>
90-
void dpnp_partition_c(const void* sort_array1_in, void* result1, const size_t kth, const size_t* shape, const size_t ndim)
90+
void dpnp_partition_c(void* array1_in, void* array2_in, void* result1, const size_t kth, const size_t* shape_, const size_t ndim)
9191
{
92+
_DataType* arr = reinterpret_cast<_DataType*>(array1_in);
93+
_DataType* arr2 = reinterpret_cast<_DataType*>(array2_in);
94+
_DataType* result = reinterpret_cast<_DataType*>(result1);
9295

93-
cl::sycl::event event;
96+
if ((arr == nullptr) || (result == nullptr))
97+
{
98+
return;
99+
}
94100

95-
const _DataType* sort_arr = reinterpret_cast<const _DataType*>(sort_array1_in);
96-
_DataType* result = reinterpret_cast<_DataType*>(result1);
101+
if (ndim < 1)
102+
{
103+
return;
104+
}
97105

98-
size_t size_ = 1;
99-
for (size_t i = 0; i < ndim - 1; ++i)
106+
size_t size = 1;
107+
for (size_t i = 0; i < ndim; ++i)
100108
{
101-
size_ *= shape[i];
109+
size *= shape_[i];
102110
}
103111

112+
size_t size_ = size/shape_[ndim-1];
113+
104114
if (size_ == 0)
105115
{
106116
return;
107117
}
108118

119+
auto arr_to_result_event = DPNP_QUEUE.memcpy(result, arr, size * sizeof(_DataType));
120+
arr_to_result_event.wait();
121+
122+
for (size_t i = 0; i < size_; ++i)
123+
{
124+
size_t ind_begin = i * shape_[ndim-1];
125+
size_t ind_end = (i + 1) * shape_[ndim-1] - 1;
126+
127+
_DataType matrix[shape_[ndim-1]];
128+
for (size_t j = ind_begin; j < ind_end + 1; ++j)
129+
{
130+
size_t ind = j - ind_begin;
131+
matrix[ind] = arr2[j];
132+
}
133+
std::partial_sort(matrix, matrix + shape_[ndim-1], matrix + shape_[ndim-1]);
134+
for (size_t j = ind_begin; j < ind_end + 1; ++j)
135+
{
136+
size_t ind = j - ind_begin;
137+
arr2[j] = matrix[ind];
138+
}
139+
}
140+
141+
    size_t* shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(size_t)));
142+
auto memcpy_event = DPNP_QUEUE.memcpy(shape, shape_, ndim * sizeof(size_t));
143+
144+
memcpy_event.wait();
145+
109146
cl::sycl::range<2> gws(size_, kth+1);
110147
auto kernel_parallel_for_func = [=](cl::sycl::id<2> global_id) {
111148
size_t j = global_id[0];
112149
size_t k = global_id[1];
113150

114-
_DataType val = sort_arr[j * shape[ndim - 1] + k];
151+
_DataType val = arr2[j * shape[ndim - 1] + k];
115152

116-
size_t ind = j * shape[ndim - 1] + k;
117153
for (size_t i = 0; i < shape[ndim - 1]; ++i)
118154
{
119155
if (result[j * shape[ndim - 1] + i] == val)
120156
{
121-
ind = j * shape[ndim - 1] + i;
122-
break;
157+
_DataType change_val1 = result[j * shape[ndim - 1] + i];
158+
_DataType change_val2 = result[j * shape[ndim - 1] + k];
159+
result[j * shape[ndim - 1] + k] = change_val1;
160+
result[j * shape[ndim - 1] + i] = change_val2;
123161
}
124162
}
125163

126-
_DataType change_val = result[j * shape[ndim - 1] + k];
127-
result[j * shape[ndim - 1] + k] = val;
128-
result[ind] = change_val;
129164
};
130165

131166
auto kernel_func = [&](cl::sycl::handler& cgh) {
167+
cgh.depends_on({memcpy_event});
132168
cgh.parallel_for<class dpnp_partition_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
133169
};
134170

135-
event = DPNP_QUEUE.submit(kernel_func);
171+
auto event = DPNP_QUEUE.submit(kernel_func);
136172

137173
event.wait();
174+
175+
    dpnp_memory_free_c(shape);
138176
}
139177

140178
template <typename _DataType>

dpnp/dpnp_algo/dpnp_algo_sorting.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ __all__ += [
4444
]
4545

4646

47-
ctypedef void(*fptr_dpnp_partition_t)(const void * , void * , const size_t , const size_t * , const size_t)
47+
ctypedef void(*fptr_dpnp_partition_t)(void * , void * , void * , const size_t , const size_t * , const size_t)
4848

4949

5050
cpdef dparray dpnp_argsort(dparray in_array1):
@@ -58,12 +58,12 @@ cpdef dparray dpnp_partition(dparray arr, int kth, axis=-1, kind='introselect',
5858
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PARTITION, param1_type, param1_type)
5959

6060
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
61-
cdef dparray result = dpnp.copy(arr)
62-
cdef dparray sort_arr = dpnp.sort(arr)
61+
cdef dparray arr2 = dpnp.copy(arr)
62+
cdef dparray result = dparray(arr.shape, dtype=result_type)
6363

6464
cdef fptr_dpnp_partition_t func = <fptr_dpnp_partition_t > kernel_data.ptr
6565

66-
func(sort_arr.get_data(), result.get_data(), kth_, < size_t * > arr._dparray_shape.data(), arr.ndim)
66+
func(arr.get_data(), arr2.get_data(), result.get_data(), kth_, < size_t * > arr._dparray_shape.data(), arr.ndim)
6767

6868
return result
6969

0 commit comments

Comments
 (0)