5
5
#include < cstdint>
6
6
#include < cuda.h>
7
7
8
+ #include < string>
9
+
8
10
typedef uint64_t u64 ;
9
11
10
12
// CHECK: void exec_kernel(dpct::kernel_function cuFunc, dpct::kernel_library cuMod, dpct::queue_ptr stream) {
@@ -15,9 +17,10 @@ void exec_kernel(CUfunction cuFunc, CUmodule cuMod, CUstream stream) {
15
17
// verify the conversion from dpct::kernel_library to uint64_t
16
18
mod = (u64 )cuMod;
17
19
20
+ std::string kernel_name{" kfoo" };
18
21
// verify the conversion from uint64_t to dpct::kernel_library
19
- // CHECK: cuFunc = dpct::get_kernel_function((dpct::kernel_library)mod, "kfoo" );
20
- cuModuleGetFunction (&cuFunc, (CUmodule)mod, " kfoo " );
22
+ // CHECK: cuFunc = dpct::get_kernel_function((dpct::kernel_library)mod, kernel_name.c_str() );
23
+ cuModuleGetFunction (&cuFunc, (CUmodule)mod, kernel_name. c_str () );
21
24
22
25
// verify the conversion from dpct::kernel_function to uint64_t
23
26
function = (u64 )cuFunc;
@@ -28,3 +31,18 @@ void exec_kernel(CUfunction cuFunc, CUmodule cuMod, CUstream stream) {
28
31
// CHECK: dpct::invoke_kernel_function((dpct::kernel_function)function, *stream, sycl::range<3>(100, 100, 100), sycl::range<3>(100, 100, 100), 1024, NULL, config);
29
32
cuLaunchKernel ((CUfunction)function, 100 , 100 , 100 , 100 , 100 , 100 , 1024 , stream, NULL , config);
30
33
}
34
+
35
+ class CString {
36
+ private:
37
+ char *str;
38
+ public:
39
+ CString (): str(NULL ) {};
40
+ operator const char * () const { return str; }
41
+
42
+ operator char * () { return str; }
43
+ };
44
+
45
+ void test_casting (CUmodule mod, CUfunction func, const CString &name) {
46
+ // CHECK: func = dpct::get_kernel_function(mod, (const char *)name);
47
+ cuModuleGetFunction (&func, mod, name);
48
+ }
0 commit comments