Skip to content

Commit 6428610

Browse files
Checkpoint for implementing approach 2 in NVIDIAgh-4148
1 parent 49bb3a8 commit 6428610

File tree

5 files changed

+107
-7
lines changed

5 files changed

+107
-7
lines changed

c/parallel/include/cccl/c/types.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include <cccl/c/extern_c.h>
2424
#include <stddef.h>
25+
#include <stdint.h>
2526

2627
CCCL_C_EXTERN_C_BEGIN
2728

@@ -76,6 +77,14 @@ typedef struct cccl_value_t
7677
void* state;
7778
} cccl_value_t;
7879

80+
// Offset value handed to a host-side iterator advance callback.
// The same 64 bits can be written and read either as a signed or as an
// unsigned quantity. The union carries no tag, so it does not record which
// member was last written.
typedef union
81+
{
82+
// Offset expressed as a signed 64-bit integer.
int64_t signed_offset;
83+
// Offset expressed as an unsigned 64-bit integer.
uint64_t unsigned_offset;
84+
} cccl_increment_t;
85+
86+
// Pointer to a host function taking an opaque `void*` and a
// `cccl_increment_t` by value, returning nothing.
// NOTE(review): presumably the `void*` is the iterator's state and the callee
// advances that state in place by the given offset; confirm against callers.
typedef void (*cccl_host_op_fn_ptr_t)(void*, cccl_increment_t);
87+
7988
typedef struct cccl_iterator_t
8089
{
8190
size_t size;
@@ -84,6 +93,7 @@ typedef struct cccl_iterator_t
8493
cccl_op_t advance;
8594
cccl_op_t dereference;
8695
cccl_type_info value_type;
96+
cccl_host_op_fn_ptr_t host_advance;
8797
void* state;
8898
} cccl_iterator_t;
8999

c/parallel/src/segmented_reduce.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,9 @@ CUresult cccl_device_segmented_reduce(
406406

407407
auto exec_status = cub::DispatchSegmentedReduce<
408408
indirect_arg_t, // InputIteratorT
409-
indirect_arg_t, // OutputIteratorT
410-
indirect_arg_t, // BeginSegmentIteratorT
411-
indirect_arg_t, // EndSegmentIteratorT
409+
indirect_iterator_t, // OutputIteratorT
410+
indirect_iterator_t, // BeginSegmentIteratorT
411+
indirect_iterator_t, // EndSegmentIteratorT
412412
OffsetT, // OffsetT
413413
indirect_arg_t, // ReductionOpT
414414
indirect_arg_t, // InitT

c/parallel/src/util/indirect_arg.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,66 @@ struct indirect_arg_t
3333
return ptr;
3434
}
3535
};
36+
37+
struct indirect_iterator_t
38+
{
39+
void* ptr;
40+
size_t value_size;
41+
cccl_host_op_fn_ptr_t host_advance_fn_p;
42+
43+
indirect_iterator_t(cccl_iterator_t& it)
44+
: ptr{nullptr}
45+
, value_size{0}
46+
, host_advance_fn_p{nullptr}
47+
{
48+
if (it.type == cccl_iterator_kind_t::CCCL_POINTER)
49+
{
50+
value_size = it.value_type.size;
51+
ptr = &it.state;
52+
}
53+
else
54+
{
55+
ptr = it.state;
56+
host_advance_fn_p = it.host_advance;
57+
}
58+
}
59+
60+
void* operator&() const
61+
{
62+
return ptr;
63+
}
64+
65+
void operator+=(int64_t signed_offset)
66+
{
67+
if (value_size)
68+
{
69+
// CCCL_POINTER case
70+
ptr = reinterpret_cast<void*>(reinterpret_cast<char*>(ptr) + (signed_offset * value_size));
71+
}
72+
else
73+
{
74+
if (host_advance_fn_p)
75+
{
76+
cccl_increment_t incr{.signed_offset = signed_offset};
77+
(*host_advance_fn_p)(ptr, incr);
78+
}
79+
}
80+
}
81+
82+
void operator+=(uint64_t unsigned_offset)
83+
{
84+
if (value_size)
85+
{
86+
// CCCL_POINTER case
87+
ptr = reinterpret_cast<void*>(reinterpret_cast<char*>(ptr) + (unsigned_offset * value_size));
88+
}
89+
else
90+
{
91+
if (host_advance_fn_p)
92+
{
93+
cccl_increment_t incr{.unsigned_offset = unsigned_offset};
94+
(*host_advance_fn_p)(ptr, incr);
95+
}
96+
}
97+
}
98+
};

python/cuda_parallel/cuda/parallel/experimental/_bindings.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ class Iterator:
132132
dereference_fn: Op,
133133
value_type: TypeInfo,
134134
state=None,
135+
host_advance_fn=None,
135136
):
136137
pass
137138

@@ -287,7 +288,7 @@ class DeviceMergeSortBuildResult:
287288
num_items: int,
288289
binary_op: Op,
289290
stream,
290-
) -> tuple[int, int]: ...
291+
) -> int: ...
291292

292293
# -----------------
293294
# DeviceUniqueByKey
@@ -317,7 +318,7 @@ class DeviceUniqueByKeyBuildResult:
317318
binary_op: Op,
318319
num_items: int,
319320
stream,
320-
) -> tuple[int, int]: ...
321+
) -> int: ...
321322

322323
# -----------------
323324
# DeviceRadixSort

python/cuda_parallel/cuda/parallel/experimental/_bindings.pyx

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# static type checker tools like mypy green-lights cuda.parallel
88

99
from libc.string cimport memset, memcpy
10-
from libc.stdint cimport uint8_t, uint32_t, uint64_t
10+
from libc.stdint cimport uint8_t, uint32_t, uint64_t, int64_t
1111
from cpython.bytes cimport PyBytes_FromStringAndSize
1212

1313
from cpython.buffer cimport (
@@ -68,13 +68,20 @@ cdef extern from "cccl/c/types.h":
6868
cccl_type_info type
6969
void *state
7070

71+
cdef union cccl_increment_t:
72+
int64_t signed_offset
73+
uint64_t unsigned_offset
74+
75+
ctypedef void (*cccl_host_op_fn_ptr_t)(void *, cccl_increment_t) nogil
76+
7177
cdef struct cccl_iterator_t:
7278
size_t size
7379
size_t alignment
7480
cccl_iterator_kind_t type
7581
cccl_op_t advance
7682
cccl_op_t dereference
7783
cccl_type_info value_type
84+
cccl_host_op_fn_ptr_t host_advance
7885
void *state
7986

8087
ctypedef enum cccl_sort_order_t:
@@ -799,10 +806,16 @@ cdef class IteratorState(StateBase):
799806
pass
800807

801808

809+
810+
# Converts a Python-level host advance object into a C function pointer of
# type cccl_host_op_fn_ptr_t.
# NOTE(review): this is currently a stub. `host_fn_obj` is not inspected and
# the result is always NULL, so the value stored in `iter_data.host_advance`
# by `Iterator.__cinit__` is NULL even when a `host_advance_fn` is given.
cdef cccl_host_op_fn_ptr_t unbox_host_advance_fn(object host_fn_obj) except *:
811+
# Placeholder result: no function pointer is extracted from host_fn_obj.
return <cccl_host_op_fn_ptr_t>NULL
812+
813+
802814
cdef class Iterator:
803815
cdef Op advance
804816
cdef Op dereference
805817
cdef object state_obj
818+
cdef object host_advance_obj
806819
cdef cccl_iterator_t iter_data
807820

808821
def __cinit__(self,
@@ -811,7 +824,8 @@ cdef class Iterator:
811824
Op advance_fn,
812825
Op dereference_fn,
813826
TypeInfo value_type,
814-
state = None
827+
state=None,
828+
host_advance_fn=None
815829
):
816830
cdef cccl_iterator_kind_t it_kind
817831
_validate_alignment(alignment)
@@ -836,6 +850,12 @@ cdef class Iterator:
836850
"Expect for Iterator of kind POINTER, state must have type Pointer or int, "
837851
f"got {type(state)}"
838852
)
853+
if host_advance_fn is not None:
854+
raise ValueError(
855+
"host_advance_fn must be set to None for iterators of kind POINTER"
856+
)
857+
self.iter_data.host_advance = NULL
858+
self.host_advance_obj = None
839859
elif it_kind == cccl_iterator_kind_t.CCCL_ITERATOR:
840860
if state is None:
841861
self.state_obj = None
@@ -850,6 +870,12 @@ cdef class Iterator:
850870
"For Iterator of kind ITERATOR, state must have type IteratorState, "
851871
f"got type {type(state)}"
852872
)
873+
if host_advance_fn is not None:
874+
self.iter_data.host_advance = unbox_host_advance_fn(host_advance_fn)
875+
self.host_advance_obj = host_advance_fn
876+
else:
877+
self.iter_data.host_advance = NULL
878+
self.host_advance_obj = None
853879
else: # pragma: no cover
854880
raise ValueError("Unrecognized iterator kind")
855881
self.advance = advance_fn

0 commit comments

Comments
 (0)