Skip to content

Commit 7ab68ea

Browse files
authored
Partition (#653)
* add dpnp.partition
1 parent 2fe40b7 commit 7ab68ea

File tree

9 files changed

+175
-30
lines changed

9 files changed

+175
-30
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,19 @@ INP_DLLEXPORT void dpnp_sum_c(void* result_out,
291291
const void* initial,
292292
const long* where);
293293

294+
/**
295+
* @ingroup BACKEND_API
296+
* @brief Place of array elements
297+
*
298+
* @param [in] sort_array Input sorted array.
299+
* @param [out] result Result array.
300+
* @param [in] kth Element index to partition by.
301+
* @param [in] shape Shape of input array.
302+
* @param [in] ndim Number of elements in shape.
303+
*/
304+
template <typename _DataType>
305+
INP_DLLEXPORT void dpnp_partition_c(const void* sort_array, void* result, const size_t kth, const size_t* shape, const size_t ndim);
306+
294307
/**
295308
* @ingroup BACKEND_API
296309
* @brief Place of array elements

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ enum class DPNPFuncName : size_t
137137
DPNP_FN_NONZERO, /**< Used in numpy.nonzero() implementation */
138138
DPNP_FN_ONES, /**< Used in numpy.ones() implementation */
139139
DPNP_FN_ONES_LIKE, /**< Used in numpy.ones_like() implementation */
140+
DPNP_FN_PARTITION, /**< Used in numpy.partition() implementation */
140141
DPNP_FN_PLACE, /**< Used in numpy.place() implementation */
141142
DPNP_FN_POWER, /**< Used in numpy.power() implementation */
142143
DPNP_FN_PROD, /**< Used in numpy.prod() implementation */

dpnp/backend/kernels/dpnp_krnl_sorting.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,60 @@ struct _sort_less
8383
}
8484
};
8585

86+
template <typename _DataType>
87+
class dpnp_partition_c_kernel;
88+
89+
template <typename _DataType>
90+
void dpnp_partition_c(const void* sort_array1_in, void* result1, const size_t kth, const size_t* shape, const size_t ndim)
91+
{
92+
93+
cl::sycl::event event;
94+
95+
const _DataType* sort_arr = reinterpret_cast<const _DataType*>(sort_array1_in);
96+
_DataType* result = reinterpret_cast<_DataType*>(result1);
97+
98+
size_t size_ = 1;
99+
for (size_t i = 0; i < ndim - 1; ++i)
100+
{
101+
size_ *= shape[i];
102+
}
103+
104+
if (size_ == 0)
105+
{
106+
return;
107+
}
108+
109+
cl::sycl::range<2> gws(size_, kth+1);
110+
auto kernel_parallel_for_func = [=](cl::sycl::id<2> global_id) {
111+
size_t j = global_id[0];
112+
size_t k = global_id[1];
113+
114+
_DataType val = sort_arr[j * shape[ndim - 1] + k];
115+
116+
size_t ind = j * shape[ndim - 1] + k;
117+
for (size_t i = 0; i < shape[ndim - 1]; ++i)
118+
{
119+
if (result[j * shape[ndim - 1] + i] == val)
120+
{
121+
ind = j * shape[ndim - 1] + i;
122+
break;
123+
}
124+
}
125+
126+
_DataType change_val = result[j * shape[ndim - 1] + k];
127+
result[j * shape[ndim - 1] + k] = val;
128+
result[ind] = change_val;
129+
};
130+
131+
auto kernel_func = [&](cl::sycl::handler& cgh) {
132+
cgh.parallel_for<class dpnp_partition_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
133+
};
134+
135+
event = DPNP_QUEUE.submit(kernel_func);
136+
137+
event.wait();
138+
}
139+
86140
template <typename _DataType>
87141
class dpnp_sort_c_kernel;
88142

@@ -110,6 +164,11 @@ void func_map_init_sorting(func_map_t& fmap)
110164
fmap[DPNPFuncName::DPNP_FN_ARGSORT][eft_FLT][eft_FLT] = {eft_LNG, (void*)dpnp_argsort_c<float, long>};
111165
fmap[DPNPFuncName::DPNP_FN_ARGSORT][eft_DBL][eft_DBL] = {eft_LNG, (void*)dpnp_argsort_c<double, long>};
112166

167+
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_partition_c<int>};
168+
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_partition_c<long>};
169+
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_partition_c<float>};
170+
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_partition_c<double>};
171+
113172
fmap[DPNPFuncName::DPNP_FN_SORT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_sort_c<int>};
114173
fmap[DPNPFuncName::DPNP_FN_SORT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_sort_c<long>};
115174
fmap[DPNPFuncName::DPNP_FN_SORT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_sort_c<float>};

dpnp/dparray.pyx

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,20 @@ cdef class dparray:
951951
"""
952952
return argsort(self, axis, kind, order)
953953
954+
def partition(self, kth, axis=-1, kind='introselect', order=None):
955+
"""
956+
Return a partitioned copy of an array.
957+
For full documentation refer to :obj:`numpy.partition`.
958+
959+
Limitations
960+
-----------
961+
Input array is supported as :obj:`dpnp.ndarray`.
962+
Input kth is supported as :obj:`int`.
963+
Parameters ``axis``, ``kind`` and ``order`` are supported only with default values.
964+
"""
965+
966+
return partition(self, kth, axis, kind, order)
967+
954968
def sort(self, axis=-1, kind=None, order=None):
955969
"""
956970
Sort the array

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
110110
DPNP_FN_NONZERO
111111
DPNP_FN_ONES
112112
DPNP_FN_ONES_LIKE
113+
DPNP_FN_PARTITION
113114
DPNP_FN_PLACE
114115
DPNP_FN_POWER
115116
DPNP_FN_PROD

dpnp/dpnp_algo/dpnp_algo_sorting.pyx

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,34 @@ from dpnp.dpnp_utils cimport *
3939

4040
__all__ += [
4141
"dpnp_argsort",
42+
"dpnp_partition",
4243
"dpnp_sort"
4344
]
4445

4546

47+
ctypedef void(*fptr_dpnp_partition_t)(const void * , void * , const size_t , const size_t * , const size_t)
48+
49+
4650
cpdef dparray dpnp_argsort(dparray in_array1):
4751
return call_fptr_1in_1out(DPNP_FN_ARGSORT, in_array1, in_array1.shape)
4852

4953

54+
cpdef dparray dpnp_partition(dparray arr, int kth, axis=-1, kind='introselect', order=None):
55+
cdef size_t kth_ = kth if kth >= 0 else (arr.ndim + kth)
56+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
57+
58+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PARTITION, param1_type, param1_type)
59+
60+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
61+
cdef dparray result = dpnp.copy(arr)
62+
cdef dparray sort_arr = dpnp.sort(arr)
63+
64+
cdef fptr_dpnp_partition_t func = <fptr_dpnp_partition_t > kernel_data.ptr
65+
66+
func(sort_arr.get_data(), result.get_data(), kth_, < size_t * > arr._dparray_shape.data(), arr.ndim)
67+
68+
return result
69+
70+
5071
cpdef dparray dpnp_sort(dparray in_array1):
5172
return call_fptr_1in_1out(DPNP_FN_SORT, in_array1, in_array1.shape)

dpnp/dpnp_iface_sorting.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,19 @@
4040
"""
4141

4242

43+
4344
import numpy
4445

4546
from dpnp.dpnp_algo import *
4647
from dpnp.dparray import dparray
4748
from dpnp.dpnp_utils import *
4849

50+
import dpnp
51+
4952

5053
__all__ = [
5154
'argsort',
55+
'partition',
5256
'sort'
5357
]
5458

@@ -101,6 +105,36 @@ def argsort(in_array1, axis=-1, kind=None, order=None):
101105
return numpy.argsort(in_array1, axis, kind, order)
102106

103107

108+
def partition(arr, kth, axis=-1, kind='introselect', order=None):
109+
"""
110+
Return a partitioned copy of an array.
111+
For full documentation refer to :obj:`numpy.partition`.
112+
113+
Limitations
114+
-----------
115+
Input array is supported as :obj:`dpnp.ndarray`.
116+
Input kth is supported as :obj:`int`.
117+
Parameters ``axis``, ``kind`` and ``order`` are supported only with default values.
118+
"""
119+
if not use_origin_backend():
120+
if not isinstance(arr, dparray):
121+
pass
122+
elif not isinstance(kth, int):
123+
pass
124+
elif kth >= arr.shape[arr.ndim - 1] or arr.ndim + kth < 0:
125+
pass
126+
elif axis != -1:
127+
pass
128+
elif kind != 'introselect':
129+
pass
130+
elif order is not None:
131+
pass
132+
else:
133+
return dpnp_partition(arr, kth, axis, kind, order)
134+
135+
return call_origin(numpy.partition, arr, kth, axis, kind, order)
136+
137+
104138
def sort(x1, **kwargs):
105139
"""
106140
Return a sorted copy of an array.

tests/skipped_tests.tbl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,12 +1452,6 @@ tests/third_party/cupy/sorting_tests/test_sort.py::TestArgsort_param_1_{external
14521452
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgsort_param_1_{external=True}::test_nan1
14531453
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgsort_param_1_{external=True}::test_nan2
14541454
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_axis
1455-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_invalid_axis1
1456-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_invalid_axis2
1457-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_invalid_kth
1458-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_invalid_negative_axis1
1459-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_invalid_negative_axis2
1460-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_invalid_negative_kth
14611455
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_multi_dim
14621456
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_negative_axis
14631457
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_negative_kth
@@ -1466,48 +1460,24 @@ tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{extern
14661460
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_sequence_kth
14671461
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_0_{external=False, length=10}::test_partition_zero_dim
14681462
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_axis
1469-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_invalid_axis1
1470-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_invalid_axis2
1471-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_invalid_kth
1472-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_invalid_negative_axis1
1473-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_invalid_negative_axis2
1474-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_invalid_negative_kth
14751463
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_multi_dim
14761464
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_negative_axis
14771465
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_negative_kth
14781466
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_non_contiguous
14791467
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_one_dim
14801468
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_sequence_kth
14811469
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_zero_dim
1482-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_axis
1483-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_invalid_axis1
1484-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_invalid_axis2
1485-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_invalid_kth
1486-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_invalid_negative_axis1
1487-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_invalid_negative_axis2
1488-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_invalid_negative_kth
14891470
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_multi_dim
14901471
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_negative_axis
14911472
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_negative_kth
14921473
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_non_contiguous
1493-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_none_axis
14941474
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_one_dim
1495-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_sequence_kth
14961475
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_zero_dim
1497-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_axis
1498-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_invalid_axis1
1499-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_invalid_axis2
1500-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_invalid_kth
1501-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_invalid_negative_axis1
1502-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_invalid_negative_axis2
1503-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_invalid_negative_kth
15041476
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_multi_dim
15051477
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_negative_axis
15061478
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_negative_kth
15071479
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_non_contiguous
1508-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_none_axis
15091480
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_one_dim
1510-
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_sequence_kth
15111481
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_zero_dim
15121482
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_axis
15131483
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis1

tests/test_sort.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
3+
import dpnp
4+
5+
import numpy
6+
7+
8+
@pytest.mark.parametrize("kth",
9+
[0, 1],
10+
ids=['0', '1'])
11+
@pytest.mark.parametrize("dtype",
12+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
13+
ids=['float64', 'float32', 'int64', 'int32'])
14+
@pytest.mark.parametrize("array",
15+
[[3, 4, 2, 1],
16+
[[1, 0], [3, 0]],
17+
[[3, 2], [1, 6]],
18+
[[4, 2, 3], [3, 4, 1]],
19+
[[[1, -3], [3, 0]], [[5, 2], [0, 1]], [[1, 0], [0, 1]]],
20+
[[[[8, 2], [3, 0]], [[5, 2], [0, 1]]], [[[1, 3], [3, 1]], [[5, 2], [0, 1]]]]],
21+
ids=['[3, 4, 2, 1]',
22+
'[[1, 0], [3, 0]]',
23+
'[[3, 2], [1, 6]]',
24+
'[[4, 2, 3], [3, 4, 1]]',
25+
'[[[1, -3], [3, 0]], [[5, 2], [0, 1]], [[1, 0], [0, 1]]]',
26+
'[[[[8, 2], [3, 0]], [[5, 2], [0, 1]]], [[[1, 3], [3, 1]], [[5, 2], [0, 1]]]]'])
27+
def test_partition(array, dtype, kth):
28+
a = numpy.array(array, dtype)
29+
ia = dpnp.array(array, dtype)
30+
expected = numpy.partition(a, kth)
31+
result = dpnp.partition(ia, kth)
32+
numpy.testing.assert_array_equal(expected, result)

0 commit comments

Comments
 (0)