Skip to content

Commit 9ced57c

Browse files
authored
Use device functions that accept pointer arguments in cccl.c and cuda.parallel (NVIDIA#4249)
* Change signature of old-style operators to use void* args
* Change signature of new-style operators to use void* args
* Change all c.parallel tests to define operators with void* signature
* Update to_cccl_op() to wrap user-provided function appropriately
* Update calls to `to_cccl_op()` to always have return type
* Unskip previously skipped pytests as they now work
* Enable sass verification for pytests
* Also make stateful operators have the void* signature
* Use voidptr instead of CPointer(int8) for void* arguments
* pre-commit fix
* Change merge_sort tuning to use 1 item per thread. Comment out test that failed due to too much shared memory.
* Use void* arguments for all user-defined ops in tests
* Address Python review comments
* Add a TODO re: name mangling
* Maybe we need known-first-party for `cuda.cccl`
* Use `const void*` for input arguments.
* Just inline `sizeof...(ArgTs)`
* Fix transform test
* Actually fix transform tests
* Simplifications to operation.h
* Remove unused helpers
* Fix TargetCFuncPtr type definitions
* Also const args for stateful ops

---------

Co-authored-by: Ashwin Srinath <shwina@users.noreply.github.com>
1 parent 37c768d commit 9ced57c

28 files changed

+496
-143
lines changed

c/parallel/src/jit_templates/templates/operation.h

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef _CCCL_C_PARALLEL_JIT_TEMPLATES_PREPROCESS
1414
# include <cuda/std/cstddef>
1515
# include <cuda/std/type_traits>
16+
# include <cuda/std/utility>
1617

1718
# include <cccl/c/types.h>
1819
#endif
@@ -23,10 +24,22 @@
2324
template <typename Tag, cccl_op_t_mapping Operation, cccl_type_info_mapping RetT, cccl_type_info_mapping... ArgTs>
2425
struct stateless_user_operation
2526
{
27+
// Note: The user-provided C function (Operation.operation) must match the signature:
28+
// void (const void* arg1, ..., const void* argN, void* result_ptr)
2629
__device__ decltype(RetT)::Type operator()(decltype(ArgTs)::Type... args) const
2730
{
28-
return reinterpret_cast<decltype(RetT)::Type (*)(decltype(ArgTs)::Type...)>(Operation.operation)(
29-
std::move(args)...);
31+
using TargetCFuncPtr = void (*)(const decltype(args, void())*..., void*);
32+
33+
// Cast the stored operation pointer (assumed to be void* or compatible)
34+
auto c_func_ptr = reinterpret_cast<TargetCFuncPtr>(Operation.operation);
35+
36+
// Prepare storage for the result
37+
typename decltype(RetT)::Type result;
38+
39+
// Call the C function, casting argument addresses to void*
40+
c_func_ptr((const_cast<void*>(static_cast<const void*>(&args)))..., &result);
41+
42+
return result;
3043
}
3144
};
3245

@@ -42,8 +55,20 @@ struct stateful_user_operation
4255
user_operation_state<Operation.size, Operation.alignment> state;
4356
__device__ decltype(RetT)::Type operator()(decltype(ArgTs)::Type... args)
4457
{
45-
return reinterpret_cast<decltype(RetT)::Type (*)(void*, decltype(ArgTs)::Type...)>(
46-
Operation.operation)(&state, std::move(args)...);
58+
// Note: The user provided C function (Operation.operation) must match the signature:
59+
// void (void* state, const void* arg1, ..., const void* argN, void* result_ptr)
60+
using TargetCFuncPtr = void (*)(void*, const decltype(args, void())*..., void*);
61+
62+
// Cast the stored operation pointer (assumed to be void* or compatible)
63+
auto c_func_ptr = reinterpret_cast<TargetCFuncPtr>(Operation.operation);
64+
65+
// Prepare storage for the result
66+
typename decltype(RetT)::Type result;
67+
68+
// Call the C function, passing state address, casting argument addresses to void*, and result pointer
69+
c_func_ptr(&state, (const_cast<void*>(static_cast<const void*>(&args)))..., &result);
70+
71+
return result;
4772
}
4873
};
4974

c/parallel/src/kernels/operators.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ constexpr std::string_view binary_op_template = R"XXX(
3434
)XXX";
3535

3636
constexpr std::string_view stateless_binary_op_template = R"XXX(
37-
extern "C" __device__ {0} OP_NAME(LHS_T lhs, RHS_T rhs);
37+
extern "C" __device__ void OP_NAME(const void* lhs, const void* rhs, void* out);
3838
struct op_wrapper {{
3939
__device__ {0} operator()(LHS_T lhs, RHS_T rhs) const {{
40-
return OP_NAME(lhs, rhs);
40+
{0} ret;
41+
OP_NAME(&lhs, &rhs, &ret);
42+
return ret;
4143
}}
4244
}};
4345
)XXX";
@@ -46,11 +48,13 @@ constexpr std::string_view stateful_binary_op_template = R"XXX(
4648
struct __align__(OP_ALIGNMENT) op_state {{
4749
char data[OP_SIZE];
4850
}};
49-
extern "C" __device__ {0} OP_NAME(op_state *state, LHS_T lhs, RHS_T rhs);
51+
extern "C" __device__ void OP_NAME(void* state, const void* lhs, const void* rhs, void* out);
5052
struct op_wrapper {{
5153
op_state state;
5254
__device__ {0} operator()(LHS_T lhs, RHS_T rhs) {{
53-
return OP_NAME(&state, lhs, rhs);
55+
{0} ret;
56+
OP_NAME(&state, &lhs, &rhs, &ret);
57+
return ret;
5458
}}
5559
}};
5660
)XXX";
@@ -105,10 +109,12 @@ std::string make_kernel_user_unary_operator(std::string_view input_t, std::strin
105109
)XXX";
106110

107111
constexpr std::string_view stateless_op = R"XXX(
108-
extern "C" __device__ OUTPUT_T OP_NAME(INPUT_T val);
112+
extern "C" __device__ void OP_NAME(const void* val, void* result);
109113
struct op_wrapper {
110114
__device__ OUTPUT_T operator()(INPUT_T val) const {
111-
return OP_NAME(val);
115+
OUTPUT_T out;
116+
OP_NAME(&val, &out);
117+
return out;
112118
}
113119
};
114120
)XXX";
@@ -117,13 +123,15 @@ struct op_wrapper {
117123
struct __align__(OP_ALIGNMENT) op_state {
118124
char data[OP_SIZE];
119125
};
120-
extern "C" __device__ OUTPUT_T OP_NAME(op_state* state, INPUT_T val);
126+
extern "C" __device__ void OP_NAME(op_state* state, const void* val, void* result);
121127
struct op_wrapper
122128
{
123129
op_state state;
124130
__device__ OUTPUT_T operator()(INPUT_T val)
125131
{
126-
return OP_NAME(&state, val);
132+
OUTPUT_T out;
133+
OP_NAME(&state, &val, &out);
134+
return out;
127135
}
128136
};
129137

c/parallel/src/merge_sort.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ merge_sort_runtime_tuning_policy get_policy(int cc, int key_size)
127127
// TODO: we hardcode this value in order to make sure that the merge_sort test does not fail due to the memory op
128128
// assertions. This currently happens when we pass in items and keys of type uint8_t or int16_t, and for the custom
129129
// types test as well. This will be fixed after https://github.com/NVIDIA/cccl/issues/3570 is resolved.
130-
items_per_thread = 2;
130+
items_per_thread = 1;
131131

132132
return {block_size, items_per_thread, block_size * items_per_thread};
133133
}

c/parallel/test/test_for.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ C2H_TEST("for works with custom types", "[for]")
8080
operation_t op = make_operation("op",
8181
R"XXX(
8282
struct pair { short a; size_t b; };
83-
extern "C" __device__ void op(pair* a) {a->a++; a->b++;}
83+
extern "C" __device__ void op(void* a_ptr) {
84+
pair* a = static_cast<pair*>(a_ptr);
85+
a->a++;
86+
a->b++;
87+
}
8488
)XXX");
8589

8690
std::vector<pair> input(num_items, pair{short(1), size_t(1)});
@@ -106,7 +110,7 @@ struct invocation_counter_state_t
106110
int* d_counter;
107111
};
108112

109-
C2H_TEST("for works with stateful operators", "[for]")
113+
C2H_TEST("for_each works with stateful operators", "[for_each]")
110114
{
111115
const int num_items = 1 << 12;
112116
pointer_t<int> counter(1);
@@ -115,8 +119,9 @@ C2H_TEST("for works with stateful operators", "[for]")
115119
"op",
116120
R"XXX(
117121
struct invocation_counter_state_t { int* d_counter; };
118-
extern "C" __device__ void op(invocation_counter_state_t* state, int* a) {
119-
atomicAdd(state->d_counter, *a);
122+
extern "C" __device__ void op(void* state_ptr, void* a_ptr) {
123+
invocation_counter_state_t* state = static_cast<invocation_counter_state_t*>(state_ptr);
124+
atomicAdd(state->d_counter, *static_cast<int*>(a_ptr));
120125
}
121126
)XXX",
122127
op_state);
@@ -137,7 +142,7 @@ struct large_state_t
137142
int y, z, a;
138143
};
139144

140-
C2H_TEST("for works with large stateful operators", "[for]")
145+
C2H_TEST("for_each works with large stateful operators", "[for_each]")
141146
{
142147
const int num_items = 1 << 12;
143148
pointer_t<int> counter(1);
@@ -151,8 +156,9 @@ struct large_state_t
151156
int* d_counter;
152157
int y, z, a;
153158
};
154-
extern "C" __device__ void op(large_state_t* state, int* a) {
155-
atomicAdd(state->d_counter, *a);
159+
extern "C" __device__ void op(void* state_ptr, void* a_ptr) {
160+
large_state_t* state = static_cast<large_state_t*>(state_ptr);
161+
atomicAdd(state->d_counter, *static_cast<int*>(a_ptr));
156162
}
157163
)XXX",
158164
op_state);

c/parallel/test/test_merge_sort.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,11 @@ C2H_TEST("DeviceMergeSort:SortPairsCopy works with custom types", "[merge_sort]"
185185
operation_t op = make_operation(
186186
"op",
187187
"struct key_pair { short a; size_t b; };\n"
188-
"extern \"C\" __device__ bool op(key_pair lhs, key_pair rhs) {\n"
189-
" return lhs.a == rhs.a ? lhs.b < rhs.b : lhs.a < rhs.a;\n"
188+
"extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) {\n"
189+
" key_pair* lhs = static_cast<key_pair*>(lhs_ptr);\n"
190+
" key_pair* rhs = static_cast<key_pair*>(rhs_ptr);\n"
191+
" bool* out = static_cast<bool*>(out_ptr);\n"
192+
" *out = lhs->a == rhs->a ? lhs->b < rhs->b : lhs->a < rhs->a;\n"
190193
"}");
191194
const std::vector<short> a = generate<short>(num_items);
192195
const std::vector<size_t> b = generate<size_t>(num_items);
@@ -364,7 +367,9 @@ struct large_key_pair
364367
char c[100];
365368
};
366369

367-
C2H_TEST("DeviceMergeSort:SortPairsCopy fails to build for large types due to no vsmem", "[merge_sort]")
370+
// TODO: We no longer fail to build for large types due to no vsmem. Instead, the build passes,
371+
// but we get a ptxas error about the kernel using too much shared memory.
372+
/* C2H_TEST("DeviceMergeSort:SortPairsCopy fails to build for large types due to no vsmem", "[merge_sort]")
368373
{
369374
const size_t num_items = 1;
370375
operation_t op = make_operation(
@@ -411,3 +416,4 @@ C2H_TEST("DeviceMergeSort:SortPairsCopy fails to build for large types due to no
411416
libcudacxx_path,
412417
ctk_path));
413418
}
419+
*/

c/parallel/test/test_reduce.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,11 @@ C2H_TEST("Reduce works with custom types", "[reduce]")
8080
operation_t op = make_operation(
8181
"op",
8282
"struct pair { short a; size_t b; };\n"
83-
"extern \"C\" __device__ pair op(pair lhs, pair rhs) {\n"
84-
" return pair{ lhs.a + rhs.a, lhs.b + rhs.b };\n"
83+
"extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, void* out_ptr) {\n"
84+
" pair* lhs = static_cast<pair*>(lhs_ptr);\n"
85+
" pair* rhs = static_cast<pair*>(rhs_ptr);\n"
86+
" pair* out = static_cast<pair*>(out_ptr);\n"
87+
" *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b };\n"
8588
"}");
8689
const std::vector<short> a = generate<short>(num_items);
8790
const std::vector<size_t> b = generate<size_t>(num_items);
@@ -203,9 +206,12 @@ C2H_TEST("Reduce works with stateful operators", "[reduce]")
203206
stateful_operation_t<invocation_counter_state_t> op = make_operation(
204207
"op",
205208
"struct invocation_counter_state_t { int* d_counter; };\n"
206-
"extern \"C\" __device__ int op(invocation_counter_state_t *state, int a, int b) {\n"
209+
"extern \"C\" __device__ void op(void* state_ptr, void* a_ptr, void* b_ptr, void* out_ptr) {\n"
210+
" invocation_counter_state_t* state = static_cast<invocation_counter_state_t*>(state_ptr);\n"
207211
" atomicAdd(state->d_counter, 1);\n"
208-
" return a + b;\n"
212+
" int a = *static_cast<int*>(a_ptr);\n"
213+
" int b = *static_cast<int*>(b_ptr);\n"
214+
" *static_cast<int*>(out_ptr) = a + b;\n"
209215
"}",
210216
invocation_counter_state_t{counter.ptr});
211217

c/parallel/test/test_scan.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,11 @@ C2H_TEST("Scan works with custom types", "[scan]")
130130
operation_t op = make_operation(
131131
"op",
132132
"struct pair { short a; size_t b; };\n"
133-
"extern \"C\" __device__ pair op(pair lhs, pair rhs) {\n"
134-
" return pair{ lhs.a + rhs.a, lhs.b + rhs.b };\n"
133+
"extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, void* out_ptr) {\n"
134+
" pair* lhs = static_cast<pair*>(lhs_ptr);\n"
135+
" pair* rhs = static_cast<pair*>(rhs_ptr);\n"
136+
" pair* out = static_cast<pair*>(out_ptr);\n"
137+
" *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b };\n"
135138
"}");
136139
const std::vector<short> a = generate<short>(num_items);
137140
const std::vector<size_t> b = generate<size_t>(num_items);

c/parallel/test/test_segmented_reduce.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,11 @@ struct pair {{
207207
short a;
208208
size_t b;
209209
}};
210-
extern "C" __device__ pair {0}(pair lhs, pair rhs) {{
211-
return pair{{ lhs.a + rhs.a, lhs.b + rhs.b }};
210+
extern "C" __device__ void {0}(void* lhs_ptr, void* rhs_ptr, void* out_ptr) {{
211+
pair* lhs = static_cast<pair*>(lhs_ptr);
212+
pair* rhs = static_cast<pair*>(rhs_ptr);
213+
pair* out = static_cast<pair*>(out_ptr);
214+
*out = pair{{ lhs->a + rhs->a, lhs->b + rhs->b }};
212215
}}
213216
)XXX";
214217
std::string plus_pair_op_src = std::format(plus_pair_op_template, device_op_name);

c/parallel/test/test_transform.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,10 @@ C2H_TEST("Transform works with output of different type", "[transform]")
107107
operation_t op = make_operation(
108108
"op",
109109
"struct pair { short a; size_t b; };\n"
110-
"extern \"C\" __device__ pair op(int x) {\n"
111-
" return pair{ short(x), size_t(x) };\n"
110+
"extern \"C\" __device__ void op(void* x_ptr, void* out_ptr) {\n"
111+
" int* x = static_cast<int*>(x_ptr);\n"
112+
" pair* out = static_cast<pair*>(out_ptr);\n"
113+
" *out = pair{ short(*x), size_t(*x) };\n"
112114
"}");
113115
const std::vector<int> input = generate<int>(num_items);
114116
std::vector<pair> expected(num_items);
@@ -134,8 +136,10 @@ C2H_TEST("Transform works with custom types", "[transform]")
134136
operation_t op = make_operation(
135137
"op",
136138
"struct pair { short a; size_t b; };\n"
137-
"extern \"C\" __device__ pair op(pair x) {\n"
138-
" return pair{ x.a * 2, x.b * 2 };\n"
139+
"extern \"C\" __device__ void op(void* x_ptr, void* out_ptr) {\n"
140+
" pair* x = static_cast<pair*>(x_ptr);\n"
141+
" pair* out = static_cast<pair*>(out_ptr);\n"
142+
" *out = pair{ x->a * 2, x->b * 2 };\n"
139143
"}");
140144
const std::vector<short> a = generate<short>(num_items);
141145
const std::vector<size_t> b = generate<size_t>(num_items);
@@ -219,8 +223,11 @@ C2H_TEST("Transform with binary operator", "[transform]")
219223

220224
operation_t op = make_operation(
221225
"op",
222-
"extern \"C\" __device__ int op(int x, int y) {\n"
223-
" return (x > y) ? x : y;\n"
226+
"extern \"C\" __device__ void op(void* x_ptr, void* y_ptr, void* out_ptr ) {\n"
227+
" int* x = static_cast<int*>(x_ptr);\n"
228+
" int* y = static_cast<int*>(y_ptr);\n"
229+
" int* out = static_cast<int*>(out_ptr);\n"
230+
" *out = (*x > *y) ? *x : *y;\n"
224231
"}");
225232

226233
binary_transform(input1_ptr, input2_ptr, output_ptr, num_items, op);
@@ -250,8 +257,11 @@ C2H_TEST("Binary transform with one iterator", "[transform]")
250257

251258
operation_t op = make_operation(
252259
"op",
253-
"extern \"C\" __device__ int op(int x, int y) {\n"
254-
" return (x > y) ? x : y;\n"
260+
"extern \"C\" __device__ void op(void* x_ptr, void* y_ptr, void* out_ptr) {\n"
261+
" int* x = static_cast<int*>(x_ptr);\n"
262+
" int* y = static_cast<int*>(y_ptr);\n"
263+
" int* out = static_cast<int*>(out_ptr);\n"
264+
" *out = (*x > *y) ? *x : *y;\n"
255265
"}");
256266

257267
binary_transform(input1_ptr, input2_it, output_ptr, num_items, op);

c/parallel/test/test_unique_by_key.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,11 @@ C2H_TEST("DeviceSelect::UniqueByKey works with custom types", "[device][select_u
221221
operation_t op = make_operation(
222222
"op",
223223
"struct key_pair { short a; size_t b; };\n"
224-
"extern \"C\" __device__ bool op(key_pair lhs, key_pair rhs) {\n"
225-
" return lhs.a == rhs.a && lhs.b == rhs.b;\n"
224+
"extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) {\n"
225+
" key_pair* lhs = static_cast<key_pair*>(lhs_ptr);\n"
226+
" key_pair* rhs = static_cast<key_pair*>(rhs_ptr);\n"
227+
" bool* out = static_cast<bool*>(out_ptr);\n"
228+
" *out = (lhs->a == rhs->a && lhs->b == rhs->b);\n"
226229
"}");
227230
const std::vector<short> a = generate<short>(num_items);
228231
const std::vector<size_t> b = generate<size_t>(num_items);

0 commit comments

Comments
 (0)