Skip to content

Commit 5c3a236

Browse files
Checkpoint for implementing approach 2 in NVIDIAgh-4148
1 parent a5bc5e6 commit 5c3a236

File tree

5 files changed

+107
-7
lines changed

5 files changed

+107
-7
lines changed

c/parallel/include/cccl/c/types.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include <cccl/c/extern_c.h>
2424
#include <stddef.h>
25+
#include <stdint.h>
2526

2627
CCCL_C_EXTERN_C_BEGIN
2728

@@ -76,6 +77,14 @@ typedef struct cccl_value_t
7677
void* state;
7778
} cccl_value_t;
7879

80+
typedef union
81+
{
82+
int64_t signed_offset;
83+
uint64_t unsigned_offset;
84+
} cccl_increment_t;
85+
86+
typedef void (*cccl_host_op_fn_ptr_t)(void*, cccl_increment_t);
87+
7988
typedef struct cccl_iterator_t
8089
{
8190
size_t size;
@@ -84,6 +93,7 @@ typedef struct cccl_iterator_t
8493
cccl_op_t advance;
8594
cccl_op_t dereference;
8695
cccl_type_info value_type;
96+
cccl_host_op_fn_ptr_t host_advance;
8797
void* state;
8898
} cccl_iterator_t;
8999

c/parallel/src/segmented_reduce.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,9 @@ CUresult cccl_device_segmented_reduce(
406406

407407
cub::DispatchSegmentedReduce<
408408
indirect_arg_t, // InputIteratorT
409-
indirect_arg_t, // OutputIteratorT
410-
indirect_arg_t, // BeginSegmentIteratorT
411-
indirect_arg_t, // EndSegmentIteratorT
409+
indirect_iterator_t, // OutputIteratorT
410+
indirect_iterator_t, // BeginSegmentIteratorT
411+
indirect_iterator_t, // EndSegmentIteratorT
412412
OffsetT, // OffsetT
413413
indirect_arg_t, // ReductionOpT
414414
indirect_arg_t, // InitT

c/parallel/src/util/indirect_arg.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,66 @@ struct indirect_arg_t
3333
return ptr;
3434
}
3535
};
36+
37+
struct indirect_iterator_t
38+
{
39+
void* ptr;
40+
size_t value_size;
41+
cccl_host_op_fn_ptr_t host_advance_fn_p;
42+
43+
indirect_iterator_t(cccl_iterator_t& it)
44+
: ptr{nullptr}
45+
, value_size{0}
46+
, host_advance_fn_p{nullptr}
47+
{
48+
if (it.type == cccl_iterator_kind_t::CCCL_POINTER)
49+
{
50+
value_size = it.value_type.size;
51+
ptr = &it.state;
52+
}
53+
else
54+
{
55+
ptr = it.state;
56+
host_advance_fn_p = it.host_advance;
57+
}
58+
}
59+
60+
void* operator&() const
61+
{
62+
return ptr;
63+
}
64+
65+
void operator+=(int64_t signed_offset)
66+
{
67+
if (value_size)
68+
{
69+
// CCCL_POINTER case
70+
ptr = reinterpret_cast<void*>(reinterpret_cast<char*>(ptr) + (signed_offset * value_size));
71+
}
72+
else
73+
{
74+
if (host_advance_fn_p)
75+
{
76+
cccl_increment_t incr{.signed_offset = signed_offset};
77+
(*host_advance_fn_p)(ptr, incr);
78+
}
79+
}
80+
}
81+
82+
void operator+=(uint64_t unsigned_offset)
83+
{
84+
if (value_size)
85+
{
86+
// CCCL_POINTER case
87+
ptr = reinterpret_cast<void*>(reinterpret_cast<char*>(ptr) + (unsigned_offset * value_size));
88+
}
89+
else
90+
{
91+
if (host_advance_fn_p)
92+
{
93+
cccl_increment_t incr{.unsigned_offset = unsigned_offset};
94+
(*host_advance_fn_p)(ptr, incr);
95+
}
96+
}
97+
}
98+
};

python/cuda_parallel/cuda/parallel/experimental/_bindings.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ class Iterator:
124124
dereference_fn: Op,
125125
value_type: TypeInfo,
126126
state=None,
127+
host_advance_fn=None,
127128
):
128129
pass
129130

@@ -279,7 +280,7 @@ class DeviceMergeSortBuildResult:
279280
num_items: int,
280281
binary_op: Op,
281282
stream,
282-
) -> tuple[int, int]: ...
283+
) -> int: ...
283284

284285
# -----------------
285286
# DeviceUniqueByKey
@@ -309,7 +310,7 @@ class DeviceUniqueByKeyBuildResult:
309310
binary_op: Op,
310311
num_items: int,
311312
stream,
312-
) -> tuple[int, int]: ...
313+
) -> int: ...
313314

314315
# --------------------
315316
# DeviceUnaryTransform

python/cuda_parallel/cuda/parallel/experimental/_bindings.pyx

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# static type checker tools like mypy green-lights cuda.parallel
88

99
from libc.string cimport memset, memcpy
10-
from libc.stdint cimport uint8_t, uint32_t, uint64_t
10+
from libc.stdint cimport uint8_t, uint32_t, uint64_t, int64_t
1111
from cpython.bytes cimport PyBytes_FromStringAndSize
1212

1313
from cpython.buffer cimport (
@@ -68,13 +68,20 @@ cdef extern from "cccl/c/types.h":
6868
cccl_type_info type
6969
void *state
7070

71+
cdef union cccl_increment_t:
72+
int64_t signed_offset
73+
uint64_t unsigned_offset
74+
75+
ctypedef void (*cccl_host_op_fn_ptr_t)(void *, cccl_increment_t) nogil
76+
7177
cdef struct cccl_iterator_t:
7278
size_t size
7379
size_t alignment
7480
cccl_iterator_kind_t type
7581
cccl_op_t advance
7682
cccl_op_t dereference
7783
cccl_type_info value_type
84+
cccl_host_op_fn_ptr_t host_advance
7885
void *state
7986

8087

@@ -756,10 +763,16 @@ cdef class IteratorState(StateBase):
756763
pass
757764

758765

766+
767+
cdef cccl_host_op_fn_ptr_t unbox_host_advance_fn(object host_fn_obj) except *:
768+
return <cccl_host_op_fn_ptr_t>NULL
769+
770+
759771
cdef class Iterator:
760772
cdef Op advance
761773
cdef Op dereference
762774
cdef object state_obj
775+
cdef object host_advance_obj
763776
cdef cccl_iterator_t iter_data
764777

765778
def __cinit__(self,
@@ -768,7 +781,8 @@ cdef class Iterator:
768781
Op advance_fn,
769782
Op dereference_fn,
770783
TypeInfo value_type,
771-
state = None
784+
state=None,
785+
host_advance_fn=None
772786
):
773787
cdef cccl_iterator_kind_t it_kind
774788
_validate_alignment(alignment)
@@ -793,6 +807,12 @@ cdef class Iterator:
793807
"Expect for Iterator of kind POINTER, state must have type Pointer or int, "
794808
f"got {type(state)}"
795809
)
810+
if host_advance_fn is not None:
811+
raise ValueError(
812+
"host_advance_fn must be set to None for iterators of kind POINTER"
813+
)
814+
self.iter_data.host_advance = NULL
815+
self.host_advance_obj = None
796816
elif it_kind == cccl_iterator_kind_t.CCCL_ITERATOR:
797817
if state is None:
798818
self.state_obj = None
@@ -807,6 +827,12 @@ cdef class Iterator:
807827
"For Iterator of kind ITERATOR, state must have type IteratorState, "
808828
f"got type {type(state)}"
809829
)
830+
if host_advance_fn is not None:
831+
self.iter_data.host_advance = unbox_host_advance_fn(host_advance_fn)
832+
self.host_advance_obj = host_advance_fn
833+
else:
834+
self.iter_data.host_advance = NULL
835+
self.host_advance_obj = None
810836
else: # pragma: no cover
811837
raise ValueError("Unrecognized iterator kind")
812838
self.advance = advance_fn

0 commit comments

Comments
 (0)