Skip to content

Commit 9af6b57

Browse files
author
the-slow-one
authored
[SYCLomatic] Cast the function name argument for get_kernel_function (#2806)
1 parent 0dfddb2 commit 9af6b57

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ template <class SubExprT> class CastIfNotSameExprPrinter {
149149
clang::QualType ArgType = InputArg->getType().getCanonicalType();
150150
ArgType.removeLocalFastQualifiers(clang::Qualifiers::CVRMask);
151151
bool NeedParen = false;
152+
std::cout << "Arg type: " << ArgType.getAsString() << "\n";
153+
std::cout << "Given type " << TypeInfo << "\n";
152154
if (ArgType.getAsString() != TypeInfo) {
153155
NeedParen = needExtraParens(SubExpr);
154156
Stream << "(" << TypeInfo << ")";

clang/lib/DPCT/RulesLang/APINamesDriver.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ FEATURE_REQUEST_FACTORY(
4747
ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cuModuleGetFunction", DEREF(0),
4848
CALL(MapNames::getDpctNamespace() +
4949
"get_kernel_function",
50-
ARG(1), ARG(2)))))
50+
ARG(1), CAST_IF_NOT_SAME(makeLiteral("const char *"), ARG(2))))))
5151

5252
FEATURE_REQUEST_FACTORY(
5353
HelperFeatureEnum::device_ext,

clang/test/dpct/kernel-function-typecast.cu

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <cstdint>
66
#include <cuda.h>
77

8+
#include <string>
9+
810
typedef uint64_t u64;
911

1012
// CHECK: void exec_kernel(dpct::kernel_function cuFunc, dpct::kernel_library cuMod, dpct::queue_ptr stream) {
@@ -15,9 +17,10 @@ void exec_kernel(CUfunction cuFunc, CUmodule cuMod, CUstream stream) {
1517
// verify the conversion from dpct::kernel_library to uint64_t
1618
mod = (u64)cuMod;
1719

20+
std::string kernel_name{"kfoo"};
1821
// verify the conversion from uint64_t to dpct::kernel_library
19-
// CHECK: cuFunc = dpct::get_kernel_function((dpct::kernel_library)mod, "kfoo");
20-
cuModuleGetFunction(&cuFunc, (CUmodule)mod, "kfoo");
22+
// CHECK: cuFunc = dpct::get_kernel_function((dpct::kernel_library)mod, kernel_name.c_str());
23+
cuModuleGetFunction(&cuFunc, (CUmodule)mod, kernel_name.c_str());
2124

2225
// verify the conversion from dpct::kernel_function to uint64_t
2326
function = (u64)cuFunc;
@@ -28,3 +31,18 @@ void exec_kernel(CUfunction cuFunc, CUmodule cuMod, CUstream stream) {
2831
// CHECK: dpct::invoke_kernel_function((dpct::kernel_function)function, *stream, sycl::range<3>(100, 100, 100), sycl::range<3>(100, 100, 100), 1024, NULL, config);
2932
cuLaunchKernel((CUfunction)function, 100, 100, 100, 100, 100, 100, 1024, stream, NULL, config);
3033
}
34+
35+
class CString {
36+
private:
37+
char *str;
38+
public:
39+
CString(): str(NULL) {};
40+
operator const char* () const { return str; }
41+
42+
operator char* () { return str; }
43+
};
44+
45+
void test_casting(CUmodule mod, CUfunction func, const CString &name) {
46+
// CHECK: func = dpct::get_kernel_function(mod, (const char *)name);
47+
cuModuleGetFunction(&func, mod, name);
48+
}

0 commit comments

Comments
 (0)