Skip to content

Commit 4f31d63

Browse files
authored
[SYCLomatic] enhance cub scan API migration in case thrust::plus/maximum/minimum() used as BinaryOp functor. (#2852)
Signed-off-by: intwanghao <hao3.wang@intel.com>
1 parent f3c921b commit 4f31d63

File tree

2 files changed, with 44 additions and 14 deletions.

clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp

Lines changed: 24 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -846,11 +846,17 @@ std::string CubRule::getOpRepl(const Expr *Operator) {
846846
auto processOperatorExpr = [&](const Expr *Obj) {
847847
std::string OpType = DpctGlobalInfo::getUnqualifiedTypeName(
848848
Obj->getType().getCanonicalType());
849-
if (OpType == "cub::Sum" || OpType == "cuda::std::plus<void>") {
849+
if (OpType.find("cub::Sum") != std::string::npos ||
850+
OpType.find("cuda::std::plus") != std::string::npos ||
851+
OpType.find("thrust::plus") != std::string::npos) {
850852
OpRepl = MapNames::getClNamespace() + "plus<>()";
851-
} else if (OpType == "cub::Max" || OpType == "cuda::maximum<void>") {
853+
} else if (OpType.find("cub::Max") != std::string::npos ||
854+
OpType.find("cuda::maximum") != std::string::npos ||
855+
OpType.find("thrust::maximum") != std::string::npos) {
852856
OpRepl = MapNames::getClNamespace() + "maximum<>()";
853-
} else if (OpType == "cub::Min" || OpType == "cuda::minimum<void>") {
857+
} else if (OpType.find("cub::Min") != std::string::npos ||
858+
OpType.find("cuda::minimum") != std::string::npos ||
859+
OpType.find("thrust::minimum") != std::string::npos) {
854860
OpRepl = MapNames::getClNamespace() + "minimum<>()";
855861
}
856862
};
@@ -861,17 +867,21 @@ std::string CubRule::getOpRepl(const Expr *Operator) {
861867
} else {
862868
auto CtorArg = Op->getArg(0)->IgnoreImplicitAsWritten();
863869
if (auto DRE = dyn_cast<DeclRefExpr>(CtorArg)) {
864-
auto D = DRE->getDecl();
865-
if (!D)
866-
return OpRepl;
867-
std::string OpType = DpctGlobalInfo::getUnqualifiedTypeName(
868-
D->getType().getCanonicalType());
869-
if (OpType == "cub::Sum" || OpType == "cub::Max" ||
870-
OpType == "cub::Min" || OpType == "cuda::std::plus<void>" ||
871-
OpType == "cuda::maximum<void>" ||
872-
OpType == "cuda::minimum<void>") {
873-
ExprAnalysis EA(Operator);
874-
OpRepl = EA.getReplacedString();
870+
if (auto D = DRE->getDecl()) {
871+
std::string OpType = DpctGlobalInfo::getUnqualifiedTypeName(
872+
D->getType().getCanonicalType());
873+
if (OpType.find("cub::Sum") != std::string::npos ||
874+
OpType.find("cub::Max") != std::string::npos ||
875+
OpType.find("cub::Min") != std::string::npos ||
876+
OpType.find("cuda::std::plus") != std::string::npos ||
877+
OpType.find("cuda::maximum") != std::string::npos ||
878+
OpType.find("cuda::minimum") != std::string::npos ||
879+
OpType.find("thrust::plus") != std::string::npos ||
880+
OpType.find("thrust::maximum") != std::string::npos ||
881+
OpType.find("thrust::minimum") != std::string::npos) {
882+
ExprAnalysis EA(Operator);
883+
OpRepl = EA.getReplacedString();
884+
}
875885
}
876886
} else if (auto CXXTempObj = dyn_cast<CXXTemporaryObjectExpr>(CtorArg)) {
877887
processOperatorExpr(CXXTempObj);

clang/test/dpct/cub/blocklevel/blockscan.cu

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -135,6 +135,26 @@ __global__ void InclusiveScanKernel(int* data) {
135135
data[threadid] = output;
136136
}
137137

138+
// CHECK: void InclusiveScanKernelThrustFunctor(int* data) {
139+
// CHECK: int threadid = sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2);
140+
// CHECK: int input = data[threadid];
141+
// CHECK: int output = 0;
142+
// CHECK: output = sycl::inclusive_scan_over_group(sycl::ext::oneapi::this_work_item::get_work_group<3>(), input, sycl::plus<>());
143+
// CHECK: data[threadid] = output;
144+
// CHECK: }
145+
__global__ void InclusiveScanKernelThrustFunctor(int* data) {
146+
typedef cub::BlockScan<int, 4> BlockScan;
147+
148+
__shared__ typename BlockScan::TempStorage temp1;
149+
150+
int threadid = threadIdx.x;
151+
152+
int input = data[threadid];
153+
int output = 0;
154+
BlockScan(temp1).InclusiveScan(input, output, thrust::plus<>());
155+
data[threadid] = output;
156+
}
157+
138158
//CHECK: void InclusiveScanKernel_Max(int* data) {
139159
//CHECK-EMPTY:
140160
//CHECK-NEXT: int threadid = sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2);

0 commit comments

Comments (0)