@@ -77,6 +77,9 @@ cdef extern from "cccl/c/types.h":
77
77
cccl_type_info value_type
78
78
void * state
79
79
80
+ ctypedef enum cccl_sort_order_t:
81
+ CCCL_ASCENDING
82
+ CCCL_DESCENDING
80
83
81
84
cdef void arg_type_check(
82
85
str arg_name,
@@ -410,11 +413,47 @@ cdef class Enumeration_IteratorKind(IntEnumerationBase):
410
413
def ITERATOR (self ):
411
414
return self ._iterator
412
415
416
+ cdef class Enumeration_SortOrder(IntEnumerationBase):
417
+ " Enumeration of sort orders (ascending or descending)"
418
+ cdef IntEnumerationMember _ascending
419
+ cdef IntEnumerationMember _descending
420
+
421
+ def __cinit__ (self ):
422
+ self .enum_name = " SortOrder"
423
+ self ._ascending = self .make_ASCENDING()
424
+ self ._descending = self .make_DESCENDING()
425
+
426
+ cdef IntEnumerationMember make_ASCENDING(self ):
427
+ cdef str prop_name = " ASCENDING"
428
+ return IntEnumerationMember(
429
+ type (self ),
430
+ self .enum_name,
431
+ prop_name,
432
+ cccl_sort_order_t.CCCL_ASCENDING
433
+ )
434
+
435
+ cdef IntEnumerationMember make_DESCENDING(self ):
436
+ cdef str prop_name = " DESCENDING"
437
+ return IntEnumerationMember(
438
+ type (self ),
439
+ self .enum_name,
440
+ prop_name,
441
+ cccl_sort_order_t.CCCL_DESCENDING
442
+ )
443
+
444
+ @property
445
+ def ASCENDING (self ):
446
+ return self ._ascending
447
+
448
+ @property
449
+ def DESCENDING (self ):
450
+ return self ._descending
451
+
413
452
414
453
TypeEnum = Enumeration_CCCLType()
415
454
OpKind = Enumeration_OpKind()
416
455
IteratorKind = Enumeration_IteratorKind()
417
-
456
+ SortOrder = Enumeration_SortOrder()
418
457
419
458
cpdef bint is_TypeEnum(IntEnumerationMember attr):
420
459
" Return True if attribute is a member of TypeEnum enumeration"
@@ -430,6 +469,10 @@ cpdef bint is_IteratorKind(IntEnumerationMember attr):
430
469
" Return True if attribute is a member of IteratorKind enumeration"
431
470
return attr.parent_class is Enumeration_IteratorKind
432
471
472
+ cpdef bint is_SortOrder(IntEnumerationMember attr):
473
+ " Return True if attribute is a member of SortOrder enumeration"
474
+ return attr.parent_class is Enumeration_SortOrder
475
+
433
476
434
477
cdef void _validate_alignment(int alignment) except * :
435
478
cdef uint32_t val
@@ -1621,12 +1664,147 @@ cdef class DeviceUniqueByKeyBuildResult:
1621
1664
< uint64_t> num_items,
1622
1665
c_stream
1623
1666
)
1667
+
1624
1668
if status != 0 :
1625
1669
raise RuntimeError (
1626
1670
f" Failed executing unique_by_key, error code: {status}"
1627
1671
)
1628
1672
return storage_sz
1629
1673
1674
+ # -----------------
1675
+ # DeviceRadixSort
1676
+ # -----------------
1677
+
1678
+ cdef extern from " cccl/c/radix_sort.h" :
1679
+ cdef struct cccl_device_radix_sort_build_result_t ' cccl_device_radix_sort_build_result_t' :
1680
+ pass
1681
+
1682
+ cdef CUresult cccl_device_radix_sort_build(
1683
+ cccl_device_radix_sort_build_result_t * build_ptr,
1684
+ cccl_sort_order_t sort_order,
1685
+ cccl_iterator_t d_keys_in,
1686
+ cccl_iterator_t d_values_in,
1687
+ cccl_op_t decomposer,
1688
+ const char * decomposer_return_type,
1689
+ int , int , const char * , const char * , const char * , const char *
1690
+ ) nogil
1691
+
1692
+ cdef CUresult cccl_device_radix_sort(
1693
+ cccl_device_radix_sort_build_result_t build,
1694
+ void * d_storage_ptr,
1695
+ size_t * d_storage_nbytes,
1696
+ cccl_iterator_t d_keys_in,
1697
+ cccl_iterator_t d_keys_out,
1698
+ cccl_iterator_t d_values_in,
1699
+ cccl_iterator_t d_values_out,
1700
+ cccl_op_t decomposer,
1701
+ size_t num_items,
1702
+ int begin_bit,
1703
+ int end_bit,
1704
+ bint is_overwrite_okay,
1705
+ int * selector,
1706
+ CUstream stream
1707
+ ) nogil
1708
+
1709
+ cdef CUresult cccl_device_radix_sort_cleanup(
1710
+ cccl_device_radix_sort_build_result_t * build_ptr,
1711
+ ) nogil
1712
+
1713
+
1714
+ cdef class DeviceRadixSortBuildResult:
1715
+ cdef cccl_device_radix_sort_build_result_t build_data
1716
+
1717
+ def __dealloc__ (DeviceRadixSortBuildResult self ):
1718
+ cdef CUresult status = - 1
1719
+ with nogil:
1720
+ status = cccl_device_radix_sort_cleanup(& self .build_data)
1721
+ if (status != 0 ):
1722
+ print (f" Return code {status} encountered during radix_sort result cleanup" )
1723
+
1724
+ def __cinit__ (
1725
+ DeviceRadixSortBuildResult self ,
1726
+ cccl_sort_order_t order ,
1727
+ Iterator d_keys_in ,
1728
+ Iterator d_values_in ,
1729
+ Op decomposer_op ,
1730
+ const char* decomposer_return_type ,
1731
+ CommonData common_data
1732
+ ):
1733
+ cdef CUresult status = - 1
1734
+ cdef int cc_major = common_data.get_cc_major()
1735
+ cdef int cc_minor = common_data.get_cc_minor()
1736
+ cdef const char * cub_path = common_data.cub_path_get_c_str()
1737
+ cdef const char * thrust_path = common_data.thrust_path_get_c_str()
1738
+ cdef const char * libcudacxx_path = common_data.libcudacxx_path_get_c_str()
1739
+ cdef const char * ctk_path = common_data.ctk_path_get_c_str()
1740
+
1741
+ memset(& self .build_data, 0 , sizeof(cccl_device_radix_sort_build_result_t))
1742
+ with nogil:
1743
+ status = cccl_device_radix_sort_build(
1744
+ & self .build_data,
1745
+ order,
1746
+ d_keys_in.iter_data,
1747
+ d_values_in.iter_data,
1748
+ decomposer_op.op_data,
1749
+ decomposer_return_type,
1750
+ cc_major,
1751
+ cc_minor,
1752
+ cub_path,
1753
+ thrust_path,
1754
+ libcudacxx_path,
1755
+ ctk_path,
1756
+ )
1757
+ if status != 0 :
1758
+ raise RuntimeError (
1759
+ f" Failed building radix_sort, error code: {status}"
1760
+ )
1761
+
1762
+ cpdef tuple compute(
1763
+ DeviceRadixSortBuildResult self ,
1764
+ temp_storage_ptr,
1765
+ temp_storage_bytes,
1766
+ Iterator d_keys_in,
1767
+ Iterator d_keys_out,
1768
+ Iterator d_values_in,
1769
+ Iterator d_values_out,
1770
+ Op decomposer_op,
1771
+ size_t num_items,
1772
+ int begin_bit,
1773
+ int end_bit,
1774
+ bint is_overwrite_okay,
1775
+ selector,
1776
+ stream
1777
+ ):
1778
+ cdef CUresult status = - 1
1779
+ cdef void * storage_ptr = (< void * >< size_t> temp_storage_ptr) if temp_storage_ptr else NULL
1780
+ cdef size_t storage_sz = < size_t> temp_storage_bytes
1781
+ cdef int selector_int = < int > selector
1782
+ cdef CUstream c_stream = < CUstream>< size_t> (stream) if stream else NULL
1783
+
1784
+ with nogil:
1785
+ status = cccl_device_radix_sort(
1786
+ self .build_data,
1787
+ storage_ptr,
1788
+ & storage_sz,
1789
+ d_keys_in.iter_data,
1790
+ d_keys_out.iter_data,
1791
+ d_values_in.iter_data,
1792
+ d_values_out.iter_data,
1793
+ decomposer_op.op_data,
1794
+ < uint64_t> num_items,
1795
+ begin_bit,
1796
+ end_bit,
1797
+ is_overwrite_okay,
1798
+ & selector_int,
1799
+ c_stream
1800
+ )
1801
+
1802
+ if status != 0 :
1803
+ raise RuntimeError (
1804
+ f" Failed executing ascending radix_sort, error code: {status}"
1805
+ )
1806
+ return < object > storage_sz, < object > selector_int
1807
+
1630
1808
1631
1809
def _get_cubin (self ):
1632
1810
return self .build_data.cubin[:self .build_data.cubin_size]
0 commit comments