Skip to content

Commit 8306d99

Browse files
authored
Add Python wrappers for c.parallel radix_sort API (NVIDIA#4353)
1 parent a5bc5e6 commit 8306d99

File tree

9 files changed

+955
-8
lines changed

9 files changed

+955
-8
lines changed

python/cuda_parallel/cuda/parallel/experimental/_bindings.pyi

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import ctypes
2-
from typing import Any
2+
from typing import Any, Optional
33

44
from typing_extensions import Buffer
55

@@ -54,18 +54,26 @@ class Enumeration_IteratorKind:
5454
@property
5555
def ITERATOR(self) -> IntEnumerationMember: ...
5656

57+
class Enumeration_SortOrder:
58+
@property
59+
def ASCENDING(self) -> IntEnumerationMember: ...
60+
@property
61+
def DESCENDING(self) -> IntEnumerationMember: ...
62+
5763
TypeEnum: Enumeration_CCCLType
5864
OpKind: Enumeration_OpKind
5965
IteratorKind: Enumeration_IteratorKind
66+
SortOrder: Enumeration_SortOrder
6067

6168
def is_TypeEnum(obj) -> bool: ...
6269
def is_OpKind(obj) -> bool: ...
6370
def is_IteratorKind(obj) -> bool: ...
71+
def is_SortOrder(obj) -> bool: ...
6472

6573
class Op:
6674
def __init__(
6775
self,
68-
name: str = ...,
76+
name: Optional[str] = ...,
6977
operator_type: IntEnumerationMember = ...,
7078
ltoir=None,
7179
state=None,
@@ -264,13 +272,13 @@ class DeviceMergeSortBuildResult:
264272
d_in_keys: Iterator,
265273
d_in_items: Iterator,
266274
d_out_keys: Iterator,
267-
d_out_itemss: Iterator,
275+
d_out_items: Iterator,
268276
binary_op: Op,
269277
info: CommonData,
270278
) -> int: ...
271279
def compute(
272280
self,
273-
temp_storage_ptr: int,
281+
temp_storage_ptr: int | None,
274282
temp_storage_nbytes: int,
275283
d_in_keys: Iterator,
276284
d_in_items: Iterator,
@@ -299,7 +307,7 @@ class DeviceUniqueByKeyBuildResult:
299307
) -> int: ...
300308
def compute(
301309
self,
302-
temp_storage_ptr: int,
310+
temp_storage_ptr: int | None,
303311
temp_storage_nbytes: int,
304312
d_keys_in: Iterator,
305313
d_values_in: Iterator,
@@ -311,6 +319,29 @@ class DeviceUniqueByKeyBuildResult:
311319
stream,
312320
) -> tuple[int, int]: ...
313321

322+
# -----------------
323+
# DeviceRadixSort
324+
# -----------------
325+
326+
class DeviceRadixSortBuildResult:
327+
def __init__(self): ...
328+
def compute(
329+
self,
330+
temp_storage_ptr: int | None,
331+
temp_storage_nbytes: int,
332+
d_keys_in: Iterator,
333+
d_keys_out: Iterator,
334+
d_values_in: Iterator,
335+
d_values_out: Iterator,
336+
decomposer_op: Op,
337+
num_items: int,
338+
begin_bit: int,
339+
end_bit: int,
340+
is_overwrite_okay: bool,
341+
selector: int,
342+
stream,
343+
) -> tuple[int, int]: ...
344+
314345
# --------------------
315346
# DeviceUnaryTransform
316347
# --------------------

python/cuda_parallel/cuda/parallel/experimental/_bindings.pyx

Lines changed: 179 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ cdef extern from "cccl/c/types.h":
7777
cccl_type_info value_type
7878
void *state
7979

80+
ctypedef enum cccl_sort_order_t:
81+
CCCL_ASCENDING
82+
CCCL_DESCENDING
8083

8184
cdef void arg_type_check(
8285
str arg_name,
@@ -410,11 +413,47 @@ cdef class Enumeration_IteratorKind(IntEnumerationBase):
410413
def ITERATOR(self):
411414
return self._iterator
412415

416+
cdef class Enumeration_SortOrder(IntEnumerationBase):
417+
"Enumeration of sort orders (ascending or descending)"
418+
cdef IntEnumerationMember _ascending
419+
cdef IntEnumerationMember _descending
420+
421+
def __cinit__(self):
422+
self.enum_name = "SortOrder"
423+
self._ascending = self.make_ASCENDING()
424+
self._descending = self.make_DESCENDING()
425+
426+
cdef IntEnumerationMember make_ASCENDING(self):
427+
cdef str prop_name = "ASCENDING"
428+
return IntEnumerationMember(
429+
type(self),
430+
self.enum_name,
431+
prop_name,
432+
cccl_sort_order_t.CCCL_ASCENDING
433+
)
434+
435+
cdef IntEnumerationMember make_DESCENDING(self):
436+
cdef str prop_name = "DESCENDING"
437+
return IntEnumerationMember(
438+
type(self),
439+
self.enum_name,
440+
prop_name,
441+
cccl_sort_order_t.CCCL_DESCENDING
442+
)
443+
444+
@property
445+
def ASCENDING(self):
446+
return self._ascending
447+
448+
@property
449+
def DESCENDING(self):
450+
return self._descending
451+
413452

414453
TypeEnum = Enumeration_CCCLType()
415454
OpKind = Enumeration_OpKind()
416455
IteratorKind = Enumeration_IteratorKind()
417-
456+
SortOrder = Enumeration_SortOrder()
418457

419458
cpdef bint is_TypeEnum(IntEnumerationMember attr):
420459
"Return True if attribute is a member of TypeEnum enumeration"
@@ -430,6 +469,10 @@ cpdef bint is_IteratorKind(IntEnumerationMember attr):
430469
"Return True if attribute is a member of IteratorKind enumeration"
431470
return attr.parent_class is Enumeration_IteratorKind
432471

472+
cpdef bint is_SortOrder(IntEnumerationMember attr):
473+
"Return True if attribute is a member of SortOrder enumeration"
474+
return attr.parent_class is Enumeration_SortOrder
475+
433476

434477
cdef void _validate_alignment(int alignment) except *:
435478
cdef uint32_t val
@@ -1621,12 +1664,147 @@ cdef class DeviceUniqueByKeyBuildResult:
16211664
<uint64_t>num_items,
16221665
c_stream
16231666
)
1667+
16241668
if status != 0:
16251669
raise RuntimeError(
16261670
f"Failed executing unique_by_key, error code: {status}"
16271671
)
16281672
return storage_sz
16291673

1674+
# -----------------
1675+
# DeviceRadixSort
1676+
# -----------------
1677+
1678+
cdef extern from "cccl/c/radix_sort.h":
1679+
cdef struct cccl_device_radix_sort_build_result_t 'cccl_device_radix_sort_build_result_t':
1680+
pass
1681+
1682+
cdef CUresult cccl_device_radix_sort_build(
1683+
cccl_device_radix_sort_build_result_t *build_ptr,
1684+
cccl_sort_order_t sort_order,
1685+
cccl_iterator_t d_keys_in,
1686+
cccl_iterator_t d_values_in,
1687+
cccl_op_t decomposer,
1688+
const char* decomposer_return_type,
1689+
int, int, const char *, const char *, const char *, const char *
1690+
) nogil
1691+
1692+
cdef CUresult cccl_device_radix_sort(
1693+
cccl_device_radix_sort_build_result_t build,
1694+
void *d_storage_ptr,
1695+
size_t *d_storage_nbytes,
1696+
cccl_iterator_t d_keys_in,
1697+
cccl_iterator_t d_keys_out,
1698+
cccl_iterator_t d_values_in,
1699+
cccl_iterator_t d_values_out,
1700+
cccl_op_t decomposer,
1701+
size_t num_items,
1702+
int begin_bit,
1703+
int end_bit,
1704+
bint is_overwrite_okay,
1705+
int* selector,
1706+
CUstream stream
1707+
) nogil
1708+
1709+
cdef CUresult cccl_device_radix_sort_cleanup(
1710+
cccl_device_radix_sort_build_result_t *build_ptr,
1711+
) nogil
1712+
1713+
1714+
cdef class DeviceRadixSortBuildResult:
1715+
cdef cccl_device_radix_sort_build_result_t build_data
1716+
1717+
def __dealloc__(DeviceRadixSortBuildResult self):
1718+
cdef CUresult status = -1
1719+
with nogil:
1720+
status = cccl_device_radix_sort_cleanup(&self.build_data)
1721+
if (status != 0):
1722+
print(f"Return code {status} encountered during radix_sort result cleanup")
1723+
1724+
def __cinit__(
1725+
DeviceRadixSortBuildResult self,
1726+
cccl_sort_order_t order,
1727+
Iterator d_keys_in,
1728+
Iterator d_values_in,
1729+
Op decomposer_op,
1730+
const char* decomposer_return_type,
1731+
CommonData common_data
1732+
):
1733+
cdef CUresult status = -1
1734+
cdef int cc_major = common_data.get_cc_major()
1735+
cdef int cc_minor = common_data.get_cc_minor()
1736+
cdef const char *cub_path = common_data.cub_path_get_c_str()
1737+
cdef const char *thrust_path = common_data.thrust_path_get_c_str()
1738+
cdef const char *libcudacxx_path = common_data.libcudacxx_path_get_c_str()
1739+
cdef const char *ctk_path = common_data.ctk_path_get_c_str()
1740+
1741+
memset(&self.build_data, 0, sizeof(cccl_device_radix_sort_build_result_t))
1742+
with nogil:
1743+
status = cccl_device_radix_sort_build(
1744+
&self.build_data,
1745+
order,
1746+
d_keys_in.iter_data,
1747+
d_values_in.iter_data,
1748+
decomposer_op.op_data,
1749+
decomposer_return_type,
1750+
cc_major,
1751+
cc_minor,
1752+
cub_path,
1753+
thrust_path,
1754+
libcudacxx_path,
1755+
ctk_path,
1756+
)
1757+
if status != 0:
1758+
raise RuntimeError(
1759+
f"Failed building radix_sort, error code: {status}"
1760+
)
1761+
1762+
cpdef tuple compute(
1763+
DeviceRadixSortBuildResult self,
1764+
temp_storage_ptr,
1765+
temp_storage_bytes,
1766+
Iterator d_keys_in,
1767+
Iterator d_keys_out,
1768+
Iterator d_values_in,
1769+
Iterator d_values_out,
1770+
Op decomposer_op,
1771+
size_t num_items,
1772+
int begin_bit,
1773+
int end_bit,
1774+
bint is_overwrite_okay,
1775+
selector,
1776+
stream
1777+
):
1778+
cdef CUresult status = -1
1779+
cdef void *storage_ptr = (<void *><size_t>temp_storage_ptr) if temp_storage_ptr else NULL
1780+
cdef size_t storage_sz = <size_t>temp_storage_bytes
1781+
cdef int selector_int = <int>selector
1782+
cdef CUstream c_stream = <CUstream><size_t>(stream) if stream else NULL
1783+
1784+
with nogil:
1785+
status = cccl_device_radix_sort(
1786+
self.build_data,
1787+
storage_ptr,
1788+
&storage_sz,
1789+
d_keys_in.iter_data,
1790+
d_keys_out.iter_data,
1791+
d_values_in.iter_data,
1792+
d_values_out.iter_data,
1793+
decomposer_op.op_data,
1794+
<uint64_t>num_items,
1795+
begin_bit,
1796+
end_bit,
1797+
is_overwrite_okay,
1798+
&selector_int,
1799+
c_stream
1800+
)
1801+
1802+
if status != 0:
1803+
raise RuntimeError(
1804+
f"Failed executing ascending radix_sort, error code: {status}"
1805+
)
1806+
return <object>storage_sz, <object>selector_int
1807+
16301808

16311809
def _get_cubin(self):
16321810
return self.build_data.cubin[:self.build_data.cubin_size]

python/cuda_parallel/cuda/parallel/experimental/_cccl_interop.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,16 @@ def to_cccl_value(array_or_struct: np.ndarray | GpuStruct) -> Value:
178178
return to_cccl_value(array_or_struct._data)
179179

180180

181-
def to_cccl_op(op: Callable, sig) -> Op:
181+
def to_cccl_op(op: Callable | None, sig) -> Op:
182+
if op is None:
183+
return Op(
184+
operator_type=OpKind.STATELESS,
185+
name=None,
186+
ltoir=None,
187+
state_alignment=1,
188+
state=None,
189+
)
190+
182191
ltoir, _ = cuda.compile(op, sig=sig, output="ltoir")
183192
return Op(
184193
operator_type=OpKind.STATELESS,

python/cuda_parallel/cuda/parallel/experimental/algorithms/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55

66
from ._merge_sort import merge_sort as merge_sort
7+
from ._radix_sort import DoubleBuffer, SortOrder
8+
from ._radix_sort import radix_sort as radix_sort
79
from ._reduce import reduce_into as reduce_into
810
from ._scan import exclusive_scan as exclusive_scan
911
from ._scan import inclusive_scan as inclusive_scan
@@ -18,6 +20,9 @@
1820
"inclusive_scan",
1921
"segmented_reduce",
2022
"unique_by_key",
23+
"radix_sort",
24+
"DoubleBuffer",
25+
"SortOrder",
2126
"binary_transform",
2227
"unary_transform",
2328
]

python/cuda_parallel/cuda/parallel/experimental/algorithms/_merge_sort.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def __init__(
7272
d_out_items: DeviceArrayLike | None,
7373
op: Callable,
7474
):
75-
assert (d_in_items is None) == (d_out_items is None)
75+
present_in_values = d_in_items is not None
76+
present_out_values = d_out_items is not None
77+
assert present_in_values == present_out_values
7678

7779
self.d_in_keys_cccl = cccl.to_cccl_iter(d_in_keys)
7880
self.d_in_items_cccl = cccl.to_cccl_iter(d_in_items)

0 commit comments

Comments
 (0)