Skip to content

[SYCLomatic] Fix the parse the const memory througth parameter of function #2922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion clang/examples/DPCT/Runtime/cudaGetSymbolAddress.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
void test(void **pDev, const void *symbol) {
#define MAX_CONST_SIZE 1024
__constant__ char symbol[MAX_CONST_SIZE];

void test(void **pDev) {
// Start
cudaGetSymbolAddress(pDev /*void ***/, symbol /*const void **/);
// End
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/DPCT/DPCTOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ DPCT_ENUM_OPTION(
DPCT_OPTION_ENUM_VALUE(
"device_global", int(ExperimentalFeatures::Exp_DeviceGlobal),
"Experimental extension that allows device scoped memory "
"allocations into SYCL that can\n"
"allocations into SYCL that can "
"be accessed within a kernel using syntax similar to C++ global "
"variables.\n",
false),
Expand Down
13 changes: 4 additions & 9 deletions clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,12 @@ std::string MathFuncNameRewriter::getNewFuncName() {

auto ContextFD = getImmediateOuterFuncDecl(Call);
if (NamespaceStr == "std" && ContextFD &&
!ContextFD->hasAttr<CUDADeviceAttr>() &&
!ContextFD->hasAttr<CUDAGlobalAttr>()) {
!isGlobalOrDeviceFuncDecl(ContextFD)) {
return "";
}
// For device functions
else if ((FD->hasAttr<CUDADeviceAttr>() && !FD->hasAttr<CUDAHostAttr>()) ||
(ContextFD && (ContextFD->hasAttr<CUDADeviceAttr>() ||
ContextFD->hasAttr<CUDAGlobalAttr>()))) {
(ContextFD && isGlobalOrDeviceFuncDecl(ContextFD))) {
if (SourceCalleeName == "abs") {
// further check the type of the args.
if (!Call->getArg(0)->getType()->isIntegerType()) {
Expand Down Expand Up @@ -333,15 +331,12 @@ std::optional<std::string> MathSimulatedRewriter::rewrite() {
}

auto ContextFD = getImmediateOuterFuncDecl(Call);
if (NamespaceStr == "std" && ContextFD &&
!ContextFD->hasAttr<CUDADeviceAttr>() &&
!ContextFD->hasAttr<CUDAGlobalAttr>()) {
if (NamespaceStr == "std" && ContextFD && !isGlobalOrDeviceFuncDecl(ContextFD)) {
return {};
}

if (!FD->hasAttr<CUDADeviceAttr>() && ContextFD &&
!ContextFD->hasAttr<CUDADeviceAttr>() &&
!ContextFD->hasAttr<CUDAGlobalAttr>())
!isGlobalOrDeviceFuncDecl(ContextFD))
return Base::rewrite();

// Do not need to report warnings for pow, funnelshift, or drcp migrations
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ inline auto IsDirectCallerPureHost = [](const CallExpr *C) -> bool {
}
if (!ContextFD)
return false;
if (!ContextFD->getAttr<CUDADeviceAttr>() &&
!ContextFD->getAttr<CUDAGlobalAttr>()) {
if (!isGlobalOrDeviceFuncDecl(ContextFD)) {
return true;
}
return false;
Expand Down
13 changes: 10 additions & 3 deletions clang/lib/DPCT/RulesLang/RulesLang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5020,8 +5020,7 @@ void DeviceFunctionDeclRule::runRule(

// We need skip lambda in host code, but cannot skip lambda in device code.
if (const FunctionDecl *OuterMostFD = findTheOuterMostFunctionDecl(FD);
OuterMostFD && (!OuterMostFD->hasAttr<CUDADeviceAttr>() &&
!OuterMostFD->hasAttr<CUDAGlobalAttr>()))
OuterMostFD && !isGlobalOrDeviceFuncDecl(OuterMostFD))
return;

if (FD->isVariadic()) {
Expand Down Expand Up @@ -6813,9 +6812,17 @@ void MemoryMigrationRule::getSymbolAddressMigration(
ExprAnalysis EA;
EA.analyze(C->getArg(0));
auto StmtStrArg0 = EA.getReplacedString();
const DeclRefExpr *Arg =
dyn_cast<DeclRefExpr>(C->getArg(1)->IgnoreImplicitAsWritten());
const VarDecl *VD = dyn_cast<VarDecl>(Arg->getDecl());
EA.analyze(C->getArg(1));
auto StmtStrArg1 = EA.getReplacedString();
Replacement = "*(" + StmtStrArg0 + ")" + " = " + StmtStrArg1 + ".get_ptr()";
if (VD && VD->isLocalVarDeclOrParm()) {
StmtStrArg1 = "const_cast<void *>(" + StmtStrArg1 + ")";
} else {
StmtStrArg1 += ".get_ptr()";
}
Replacement = "*(" + StmtStrArg0 + ")" + " = " + StmtStrArg1;
requestFeature(HelperFeatureEnum::device_ext);
emplaceTransformation(new ReplaceStmt(C, std::move(Replacement)));
}
Expand Down
20 changes: 14 additions & 6 deletions clang/lib/DPCT/RulesLang/RulesLangNoneAPIAndType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,17 @@ void MemVarRefMigrationRule::runRule(const MatchFinder::MatchResult &Result) {
}
}
}
if (!HasTypeCasted && Decl->hasAttr<CUDAConstantAttr>() &&
(MemVarRef->getType()->getTypeClass() ==
Type::TypeClass::ConstantArray)) {
auto FD = dpct::DpctGlobalInfo::findAncestor<FunctionDecl>(MemVarRef);
auto CE = dpct::DpctGlobalInfo::findAncestor<CallExpr>(MemVarRef);
if (auto VD =dyn_cast<VarDecl>(MemVarRef->getDecl()); FD && VD &&
!VD->isLocalVarDeclOrParm() &&
!isGlobalOrDeviceFuncDecl(FD)) {
if (CE &&
!DpctGlobalInfo::isInCudaPath(CE->getCalleeDecl()->getBeginLoc()))
emplaceTransformation(new InsertAfterStmt(MemVarRef, ".get_ptr()"));
} else if (!HasTypeCasted && Decl->hasAttr<CUDAConstantAttr>() &&
(MemVarRef->getType()->getTypeClass() ==
Type::TypeClass::ConstantArray)) {
const Expr *RHS = getRHSOfTheNonConstAssignedVar(MemVarRef);
if (RHS) {
auto Range = GetReplRange(RHS);
Expand All @@ -235,7 +243,7 @@ void MemVarRefMigrationRule::runRule(const MatchFinder::MatchResult &Result) {
if (VD == nullptr)
return;
auto Var = Global.findMemVarInfo(VD);
if (Func->hasAttr<CUDAGlobalAttr>() || Func->hasAttr<CUDADeviceAttr>()) {
if (isGlobalOrDeviceFuncDecl(Func)) {
if (DpctGlobalInfo::useGroupLocalMemory() &&
VD->hasAttr<CUDASharedAttr>() && VD->getStorageClass() != SC_Extern) {
if (!Var)
Expand Down Expand Up @@ -829,7 +837,7 @@ void MemVarAnalysisRule::runRule(const MatchFinder::MatchResult &Result) {
return;
}
auto Var = MemVarInfo::buildMemVarInfo(VD);
if (Func->hasAttr<CUDAGlobalAttr>() || Func->hasAttr<CUDADeviceAttr>()) {
if (isGlobalOrDeviceFuncDecl(Func)) {
if (!(DpctGlobalInfo::useGroupLocalMemory() &&
VD->hasAttr<CUDASharedAttr>() &&
VD->getStorageClass() != SC_Extern)) {
Expand Down Expand Up @@ -1025,7 +1033,7 @@ void ZeroLengthArrayRule::runRule(const MatchFinder::MatchResult &Result) {
const clang::FunctionDecl *FD = DpctGlobalInfo::getParentFunction(TL);
if (FD) {
// Check if the array is in device code
if (!(FD->getAttr<CUDADeviceAttr>()) && !(FD->getAttr<CUDAGlobalAttr>()))
if (!isGlobalOrDeviceFuncDecl(FD))
return;
}
}
Expand Down
5 changes: 2 additions & 3 deletions clang/lib/DPCT/RulesLang/RulesLangTexture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ const Expr *TextureRule::getAssignedBO(const Expr *E, ASTContext &Context) {
bool TextureRule::processTexVarDeclInDevice(const VarDecl *VD) {
if (auto FD =
dyn_cast_or_null<FunctionDecl>(VD->getParentFunctionOrMethod())) {
if (FD->hasAttr<CUDAGlobalAttr>() || FD->hasAttr<CUDADeviceAttr>()) {
if (isGlobalOrDeviceFuncDecl(FD)) {
auto Tex = DpctGlobalInfo::getInstance().insertTextureInfo(VD);

auto DataType = Tex->getType()->getDataType();
Expand Down Expand Up @@ -1009,8 +1009,7 @@ void TextureRule::runRule(const MatchFinder::MatchResult &Result) {
return;
}
if (auto FD = DpctGlobalInfo::getParentFunction(TL)) {
if ((FD->hasAttr<CUDAGlobalAttr>() || FD->hasAttr<CUDADeviceAttr>()) &&
!DpctGlobalInfo::useExtBindlessImages()) {
if (isGlobalOrDeviceFuncDecl(FD) && !DpctGlobalInfo::useExtBindlessImages()) {
return;
}
}
Expand Down
6 changes: 2 additions & 4 deletions clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,8 +1054,7 @@ void CubRule::processCubTypeDefOrUsing(const TypedefNameDecl *TD) {
MapNames::getClNamespace() + "sub_group", SM));
} else if (CanonicalTypeStr.find("Block") != std::string::npos) {
auto DeviceFuncDecl = DpctGlobalInfo::findAncestor<FunctionDecl>(TD);
if (DeviceFuncDecl && (DeviceFuncDecl->hasAttr<CUDADeviceAttr>() ||
DeviceFuncDecl->hasAttr<CUDAGlobalAttr>())) {
if (DeviceFuncDecl && isGlobalOrDeviceFuncDecl(DeviceFuncDecl)) {
if (auto DI = DeviceFunctionDecl::LinkRedecls(DeviceFuncDecl)) {
auto &Map = DpctGlobalInfo::getInstance().getCubPlaceholderIndexMap();
Map.insert({PlaceholderIndex, DI});
Expand Down Expand Up @@ -1692,8 +1691,7 @@ void CubRule::processTypeLoc(const TypeLoc *TL) {
} else if (TypeName.find("class cub::BlockScan") == 0 ||
TypeName.find("class cub::BlockReduce") == 0) {
auto DeviceFuncDecl = DpctGlobalInfo::findAncestor<FunctionDecl>(TL);
if (DeviceFuncDecl && (DeviceFuncDecl->hasAttr<CUDADeviceAttr>() ||
DeviceFuncDecl->hasAttr<CUDAGlobalAttr>())) {
if (DeviceFuncDecl && isGlobalOrDeviceFuncDecl(DeviceFuncDecl)) {
if (auto DI = DeviceFunctionDecl::LinkRedecls(DeviceFuncDecl)) {
auto &Map = DpctGlobalInfo::getInstance().getCubPlaceholderIndexMap();
Map.insert({PlaceholderIndex, DI});
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/DPCT/RulesLangLib/ThrustAPIMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ void ThrustAPIRule::thrustFuncMigration(const MatchFinder::MatchResult &Result,
// thrust::count, thrust::equal) called in device function , should be
// migrated to oneapi::dpl APIs without a policy on the SYCL side
if (auto FD = DpctGlobalInfo::getParentFunction(CE)) {
if (FD->hasAttr<CUDAGlobalAttr>() || FD->hasAttr<CUDADeviceAttr>()) {
if (isGlobalOrDeviceFuncDecl(FD)) {
if (hasExecutionPolicy) {
emplaceTransformation(removeArg(CE, 0, *Result.SourceManager));
}
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/DPCT/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,11 @@ bool isCudaMemoryAllocation(const DeclRefExpr *Arg, const CallExpr *CE) {
return false;
}

bool isGlobalOrDeviceFuncDecl(const FunctionDecl *FD) {
if (FD->hasAttr<CUDADeviceAttr>() || FD->hasAttr<CUDAGlobalAttr>())
return true;
return false;
}
/// This function traverses all the nodes in the AST represented by \param Root
/// in a depth-first manner, until the node \param Sentinal is reached, to check
/// if the pointer \param Arg to a piece of memory is used as lvalue after the
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ bool isTypeInAnalysisScope(const clang::Type *TypePtr);
bool isCubVar(const clang::VarDecl *VD);
bool isCubTempStorageType(QualType T);
bool isCubCollectiveRecordType(QualType T);
bool isGlobalOrDeviceFuncDecl(const FunctionDecl *FD);
bool isExprUsed(const clang::Expr *E, bool &Result);
bool isUserDefinedDecl(const clang::Decl *D);
bool isLambda(const clang::FunctionDecl *FD);
Expand Down
52 changes: 52 additions & 0 deletions clang/test/dpct/cuda_const_pass_by_param.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

// RUN: dpct --format-range=none --usm-level=none -out-root %T/cuda_const_pass_by_param %s --cuda-include-path="%cuda-path/include" --sycl-named-lambda -- -x cuda --cuda-host-only
// RUN: FileCheck %s --match-full-lines --input-file %T/cuda_const_pass_by_param/cuda_const_pass_by_param.dp.cpp
// RUN: %if build_lit %{icpx -c -fsycl %T/cuda_const_pass_by_param/cuda_const_pass_by_param.dp.cpp -o %T/cuda_const_pass_by_param/cuda_const_pass_by_param.dp.o %}
#include <cstdio>
#include <cuda_runtime.h>

#define MAX_CONST_SIZE 1024
__constant__ char device_const_buffer[MAX_CONST_SIZE];


__host__ void* qudaGetSymbolAddress(const void* symbol) {

void* ptr;
// CHECK: *(&ptr) = const_cast<void *>(symbol);
cudaGetSymbolAddress(&ptr, symbol);
return ptr;

}

__host__ void* qudaGetSymbolAddress2() {

void* ptr;
// CHECK: *(&ptr) = device_const_buffer.get_ptr();
cudaGetSymbolAddress(&ptr, device_const_buffer);
return ptr;

}


template <typename T>
__host__ void process_buffer(T* data) {

if(data) printf("Processed: %f\n", static_cast<float>(data[0]));
}


int main() {
float h_data[256];
for(int i=0; i<256; i++) h_data[i] = i*1.0f;
// CHECK: dpct::dpct_memcpy(device_const_buffer.get_ptr(), h_data, sizeof(h_data));
cudaMemcpyToSymbol(device_const_buffer, h_data, sizeof(h_data));
// CHECK: void* host_ptr = qudaGetSymbolAddress(device_const_buffer.get_ptr());
void* host_ptr = qudaGetSymbolAddress(device_const_buffer);
void* host_ptr2 = qudaGetSymbolAddress2();
process_buffer<float>(static_cast<float*>(host_ptr));
cudaDeviceSynchronize();

return 0;
}


3 changes: 1 addition & 2 deletions clang/test/dpct/help_option_check/lin/help_advanced.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ Advanced DPCT options
=bindless_images - Experimental extension that allows use of bindless images APIs.
=graph - Experimental extension that allows use of SYCL Graph APIs.
=non-uniform-groups - Experimental extension that allows use of non-uniform groups.
=device_global - Experimental extension that allows device scoped memory allocations into SYCL that can
be accessed within a kernel using syntax similar to C++ global variables.
=device_global - Experimental extension that allows device scoped memory allocations into SYCL that can be accessed within a kernel using syntax similar to C++ global variables.
=virtual_mem - Experimental extension that allows for mapping of an address range onto multiple allocations of physical memory.
=in_order_queue_events - Experimental extension that allows placing the event from the last command submission into the queue and setting an external event as an implicit dependence on the next command submitted to the queue.
=non-stdandard-sycl-builtins - Experimental extension that allows use of non standard SYCL builtin functions.
Expand Down
3 changes: 1 addition & 2 deletions clang/test/dpct/help_option_check/lin/help_all.txt
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ All DPCT options
=bindless_images - Experimental extension that allows use of bindless images APIs.
=graph - Experimental extension that allows use of SYCL Graph APIs.
=non-uniform-groups - Experimental extension that allows use of non-uniform groups.
=device_global - Experimental extension that allows device scoped memory allocations into SYCL that can
be accessed within a kernel using syntax similar to C++ global variables.
=device_global - Experimental extension that allows device scoped memory allocations into SYCL that can be accessed within a kernel using syntax similar to C++ global variables.
=virtual_mem - Experimental extension that allows for mapping of an address range onto multiple allocations of physical memory.
=in_order_queue_events - Experimental extension that allows placing the event from the last command submission into the queue and setting an external event as an implicit dependence on the next command submitted to the queue.
=non-stdandard-sycl-builtins - Experimental extension that allows use of non standard SYCL builtin functions.
Expand Down
8 changes: 4 additions & 4 deletions clang/test/dpct/kernel-call.cu
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ void run_foo4(dim3 c, dim3 d) {
//CHECK-NEXT: my_kernel(result_acc_ct0.get_raw_pointer(), resultInGroup_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
//CHECK-NEXT: });
//CHECK-NEXT: });
//CHECK-NEXT: printf("%f ", result[10]);
//CHECK-NEXT: printf("%f ", result.get_ptr()[10]);
//CHECK-NEXT:}
__managed__ float result[32];
__global__ void my_kernel(float* result) {
Expand All @@ -432,7 +432,7 @@ int run_foo5 () {
//CHECK-NEXT: my_kernel(result2_acc_ct0.get_raw_pointer(), resultInGroup_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
//CHECK-NEXT: });
//CHECK-NEXT: });
//CHECK-NEXT: printf("%f ", result2[10]);
//CHECK-NEXT: printf("%f ", result2.get_ptr()[10]);
//CHECK-NEXT:}
__managed__ float result2[32];
int run_foo6 () {
Expand All @@ -453,7 +453,7 @@ int run_foo6 () {
//CHECK-NEXT: my_kernel(result3_acc_ct0.get_raw_pointer(), resultInGroup_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
//CHECK-NEXT: });
//CHECK-NEXT: });
//CHECK-NEXT: printf("%f ", result3[0]);
//CHECK-NEXT: printf("%f ", result3.get_ptr()[0]);
//CHECK-NEXT:}
__managed__ float result3;
int run_foo7 () {
Expand Down Expand Up @@ -482,7 +482,7 @@ int run_foo7 () {
//CHECK-NEXT: my_kernel2(in_ct0, out_acc_ct1.get_raw_pointer());
//CHECK-NEXT: });
//CHECK-NEXT: });
//CHECK-NEXT: printf("%f ", out[0]);
//CHECK-NEXT: printf("%f ", out.get_ptr()[0]);
//CHECK-NEXT:}

__managed__ float in;
Expand Down
8 changes: 4 additions & 4 deletions clang/test/dpct/kernel-usm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ int main() {
// CHECK-NEXT: my_kernel(result_ct0, resultInGroup_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
// CHECK-NEXT: });
// CHECK-NEXT: });
// CHECK-NEXT: printf("%f ", result[10]);
// CHECK-NEXT: printf("%f ", result.get_ptr()[10]);
// CHECK-NEXT:}
__managed__ __device__ float result[32];
__global__ void my_kernel(float* result) {
Expand All @@ -92,7 +92,7 @@ int run_foo5 () {
// CHECK-NEXT: my_kernel(result2_ct0, resultInGroup_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
// CHECK-NEXT: });
// CHECK-NEXT: });
// CHECK-NEXT: printf("%f ", result2[10]);
// CHECK-NEXT: printf("%f ", result2.get_ptr()[10]);
// CHECK-NEXT:}
__managed__ float result2[32];
int run_foo6 () {
Expand All @@ -114,7 +114,7 @@ int run_foo6 () {
// CHECK-NEXT: my_kernel(result3_ct0, resultInGroup_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
// CHECK-NEXT: });
// CHECK-NEXT: });
// CHECK-NEXT: printf("%f ", result3[0]);
// CHECK-NEXT: printf("%f ", result3.get_ptr()[0]);
// CHECK-NEXT:}
__managed__ float result3;
int run_foo7 () {
Expand Down Expand Up @@ -142,7 +142,7 @@ int run_foo7 () {
// CHECK-NEXT: my_kernel2(in_ct0, out_ct1);
// CHECK-NEXT: });
// CHECK-NEXT: });
// CHECK-NEXT: printf("%f ", out[0]);
// CHECK-NEXT: printf("%f ", out.get_ptr()[0]);
// CHECK-NEXT:}
__managed__ float in;
__managed__ float out;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
// CUDAGETSYMBOLADDRESS: CUDA API:
// CUDAGETSYMBOLADDRESS-NEXT: cudaGetSymbolAddress(pDev /*void ***/, symbol /*const void **/);
// CUDAGETSYMBOLADDRESS-NEXT: Is migrated to:
// CUDAGETSYMBOLADDRESS-NEXT: *(pDev) = symbol.get_ptr();
// CUDAGETSYMBOLADDRESS-NEXT: *(pDev) = const_cast<void *>(symbol);

// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cudaGetSymbolSize | FileCheck %s -check-prefix=CUDAGETSYMBOLSIZE
// CUDAGETSYMBOLSIZE: CUDA API:
Expand Down
Loading