Skip to content

Commit d6824ae

Browse files
authored
[SYCLomatic] Refine migration of cudaEventRecord/cuEventRecord/cudaEventRecordWithFlags with dpct helper function (#2909)
Signed-off-by: chenwei.sun <chenwei.sun@intel.com>
1 parent 1d7099d commit d6824ae

8 files changed, with 69 additions and 294 deletions.

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 17 additions & 182 deletions
Original file line number | Diff line number | Diff line change
@@ -3226,9 +3226,7 @@ void EventAPICallRule::handleEventRecordWithProfilingEnabled(
32263226
const CallExpr *CE, const MatchFinder::MatchResult &Result,
32273227
bool IsAssigned) {
32283228
int NumArgs = CE->getNumArgs();
3229-
const Expr *StreamArg = CE->getArg(NumArgs - 1);
32303229
if (NumArgs == 3) { // Special process for cudaEventRecordWithFlags().
3231-
StreamArg = CE->getArg(1);
32323230
auto APIName = CE->getDirectCallee()->getNameInfo().getName().getAsString();
32333231
const Expr *SecArg = CE->getArg(2);
32343232
ExprAnalysis Arg2EA(SecArg);
@@ -3241,195 +3239,32 @@ void EventAPICallRule::handleEventRecordWithProfilingEnabled(
32413239
emplaceTransformation(removeArg(CE, 2, *Result.SourceManager));
32423240
}
32433241

3244-
auto EventArg = CE->getArg(0);
3245-
ExprAnalysis StreamEA(StreamArg);
3246-
ExprAnalysis Arg0EA(EventArg);
3247-
auto StreamName = StreamEA.getReplacedString();
3248-
auto ArgName = Arg0EA.getReplacedString();
3249-
bool IsDefaultStream = isDefaultStream(StreamArg);
32503242
auto IndentLoc = CE->getBeginLoc();
32513243
auto &SM = DpctGlobalInfo::getSourceManager();
32523244

3253-
if (needExtraParens(EventArg)) {
3254-
ArgName = "(" + ArgName + ")";
3255-
}
3256-
3257-
if (needExtraParensInMemberExpr(StreamArg)) {
3258-
StreamName = "(" + StreamName + ")";
3259-
}
3260-
32613245
if (IndentLoc.isMacroID())
32623246
IndentLoc = SM.getExpansionLoc(IndentLoc);
32633247

3264-
if (IsAssigned) {
3265-
3266-
std::string StmtStr;
3267-
if (IsDefaultStream) {
3268-
if (isPlaceholderIdxDuplicated(CE))
3269-
return;
3270-
int Index = DpctGlobalInfo::getHelperFuncReplInfoIndexThenInc();
3271-
buildTempVariableMap(Index, CE, HelperFuncType::HFT_DefaultQueue);
3272-
std::string Str;
3273-
if (!DpctGlobalInfo::useEnqueueBarrier()) {
3274-
// ext_oneapi_submit_barrier is specified in the value of option
3275-
// --no-dpcpp-extensions.
3276-
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None) {
3277-
3278-
Str = MapNames::getDpctNamespace() +
3279-
"get_current_device().queues_wait_and_throw();";
3280-
Str += getNL();
3281-
Str += getIndent(IndentLoc, SM).str();
3282-
std::string SubStr = "{{NEEDREPLACEQ" + std::to_string(Index) +
3283-
"}}.single_task([=](){});";
3284-
SubStr = "*" + ArgName + " = " + SubStr;
3285-
Str += SubStr;
3286-
3287-
Str += getNL();
3288-
Str += getIndent(IndentLoc, SM).str();
3289-
Str += MapNames::getDpctNamespace() +
3290-
"get_current_device().queues_wait_and_throw();";
3291-
Str += getNL();
3292-
Str += getIndent(IndentLoc, SM).str();
3293-
Str += "return 0;";
3294-
3295-
Str = "[&](){" + Str + "}()";
3296-
emplaceTransformation(new ReplaceStmt(CE, std::move(Str)));
3297-
return;
3298-
}
3299-
Str = "{{NEEDREPLACEQ" + std::to_string(Index) +
3300-
"}}.single_task([=](){})";
3301-
3302-
} else {
3303-
if (DpctGlobalInfo::useSYCLCompat()) {
3304-
report(CE->getBeginLoc(), Diagnostics::UNSUPPORT_SYCLCOMPAT, false,
3305-
"cudaEventRecord");
3306-
return;
3307-
}
3308-
std::string ReplaceStr;
3309-
ReplaceStr = MapNames::getDpctNamespace() + "sync_barrier";
3310-
emplaceTransformation(new ReplaceCalleeName(CE, std::move(ReplaceStr)));
3311-
emplaceTransformation(new InsertBeforeStmt(CE, MapNames::getCheckErrorMacroName() + "("));
3312-
emplaceTransformation(new InsertAfterStmt(CE, ")"));
3313-
report(CE->getBeginLoc(), Diagnostics::NOERROR_RETURN_ZERO, false);
3314-
return;
3315-
}
3316-
StmtStr = "*" + ArgName + " = " + Str;
3317-
} else {
3318-
std::string Str;
3319-
if (!DpctGlobalInfo::useEnqueueBarrier()) {
3320-
// ext_oneapi_submit_barrier is specified in the value of option
3321-
// --no-dpcpp-extensions.
3322-
3323-
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None) {
3324-
3325-
Str = MapNames::getDpctNamespace() +
3326-
"get_current_device().queues_wait_and_throw();";
3327-
Str += getNL();
3328-
Str += getIndent(IndentLoc, SM).str();
3329-
Str += StreamName + "->" + "single_task([=](){});";
3330-
Str += getNL();
3331-
Str += getIndent(IndentLoc, SM).str();
3332-
Str += MapNames::getDpctNamespace() +
3333-
"get_current_device().queues_wait_and_throw(); return 0;";
3334-
3335-
Str = "[&](){" + Str + "}()";
3336-
emplaceTransformation(new ReplaceStmt(CE, std::move(Str)));
3337-
return;
3338-
}
3339-
Str = StreamName + "->" + "single_task([=](){})";
3340-
3341-
} else {
3342-
Str = StreamName + "->" + "ext_oneapi_submit_barrier()";
3343-
}
3344-
StmtStr = "*" + ArgName + " = " + Str;
3345-
}
3346-
StmtStr = MapNames::getCheckErrorMacroName() + "(" + StmtStr + ")";
3248+
if (isPlaceholderIdxDuplicated(CE))
3249+
return;
3250+
int Index = DpctGlobalInfo::getHelperFuncReplInfoIndexThenInc();
3251+
buildTempVariableMap(Index, CE, HelperFuncType::HFT_DefaultQueue);
33473252

3348-
emplaceTransformation(new ReplaceStmt(CE, std::move(StmtStr)));
3253+
if (DpctGlobalInfo::useSYCLCompat()) {
3254+
report(CE->getBeginLoc(), Diagnostics::UNSUPPORT_SYCLCOMPAT, false,
3255+
"cudaEventRecord");
3256+
return;
3257+
}
3258+
std::string ReplaceStr;
3259+
ReplaceStr = MapNames::getDpctNamespace() + "sync_barrier";
3260+
emplaceTransformation(new ReplaceCalleeName(CE, std::move(ReplaceStr)));
33493261

3262+
if (IsAssigned) {
3263+
emplaceTransformation(
3264+
new InsertBeforeStmt(CE, MapNames::getCheckErrorMacroName() + "("));
3265+
emplaceTransformation(new InsertAfterStmt(CE, ")"));
33503266
report(CE->getBeginLoc(), Diagnostics::NOERROR_RETURN_ZERO, false);
3351-
3352-
} else {
3353-
std::string ReplStr;
3354-
if (IsDefaultStream) {
3355-
if (isPlaceholderIdxDuplicated(CE))
3356-
return;
3357-
int Index = DpctGlobalInfo::getHelperFuncReplInfoIndexThenInc();
3358-
buildTempVariableMap(Index, CE, HelperFuncType::HFT_DefaultQueue);
3359-
std::string Str;
3360-
if (!DpctGlobalInfo::useEnqueueBarrier()) {
3361-
// ext_oneapi_submit_barrier is specified in the value of option
3362-
// --no-dpcpp-extensions.
3363-
3364-
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None) {
3365-
3366-
Str = MapNames::getDpctNamespace() +
3367-
"get_current_device().queues_wait_and_throw();";
3368-
Str += getNL();
3369-
Str += getIndent(IndentLoc, SM).str();
3370-
Str += "*" + ArgName + " = {{NEEDREPLACEQ" + std::to_string(Index) +
3371-
"}}.single_task([=](){});";
3372-
Str += getNL();
3373-
Str += getIndent(IndentLoc, SM).str();
3374-
Str += MapNames::getDpctNamespace() +
3375-
"get_current_device().queues_wait_and_throw()";
3376-
3377-
} else {
3378-
Str = "*" + ArgName + " = {{NEEDREPLACEQ" + std::to_string(Index) +
3379-
"}}.single_task([=](){})";
3380-
}
3381-
3382-
} else {
3383-
if (DpctGlobalInfo::useSYCLCompat()) {
3384-
report(CE->getBeginLoc(), Diagnostics::UNSUPPORT_SYCLCOMPAT, false,
3385-
"cudaEventRecord");
3386-
return;
3387-
}
3388-
std::string ReplaceStr;
3389-
ReplaceStr = MapNames::getDpctNamespace() + "sync_barrier";
3390-
emplaceTransformation(new ReplaceCalleeName(CE, std::move(ReplaceStr)));
3391-
return;
3392-
}
3393-
ReplStr += Str;
3394-
} else {
3395-
3396-
std::string Str;
3397-
if (!DpctGlobalInfo::useEnqueueBarrier()) {
3398-
// ext_oneapi_submit_barrier is specified in the value of option
3399-
// --no-dpcpp-extensions.
3400-
3401-
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None) {
3402-
3403-
Str = MapNames::getDpctNamespace() +
3404-
"get_current_device().queues_wait_and_throw();";
3405-
Str += getNL();
3406-
Str += getIndent(IndentLoc, SM).str();
3407-
3408-
Str += "*" + ArgName + " = " + StreamName + "->single_task([=](){});";
3409-
Str += getNL();
3410-
Str += getIndent(IndentLoc, SM).str();
3411-
Str += MapNames::getDpctNamespace() +
3412-
"get_current_device().queues_wait_and_throw()";
3413-
3414-
} else {
3415-
Str = "*" + ArgName + " = " + StreamName + "->single_task([=](){})";
3416-
}
3417-
3418-
} else {
3419-
if (DpctGlobalInfo::useSYCLCompat()) {
3420-
report(CE->getBeginLoc(), Diagnostics::UNSUPPORT_SYCLCOMPAT, false,
3421-
"cudaEventRecord");
3422-
return;
3423-
}
3424-
std::string ReplaceStr;
3425-
ReplaceStr = MapNames::getDpctNamespace() + "sync_barrier";
3426-
emplaceTransformation(new ReplaceCalleeName(CE, std::move(ReplaceStr)));
3427-
return;
3428-
}
3429-
ReplStr += Str;
3430-
}
3431-
3432-
emplaceTransformation(new ReplaceStmt(CE, std::move(ReplStr)));
3267+
return;
34333268
}
34343269
}
34353270

clang/test/dpct/disable-all-extensions.cu

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -121,20 +121,16 @@ int enqueued_barriers() {
121121
cudaEventCreate(&start);
122122
cudaEventCreate(&stop);
123123

124-
// CHECK: dpct::get_current_device().queues_wait_and_throw();
125-
// CHECK-NEXT: *start = q_ct1.single_task([=](){});
126-
// CHECK-NEXT: dpct::get_current_device().queues_wait_and_throw();
124+
// CHECK: dpct::sync_barrier(start, &q_ct1);
127125
cudaEventRecord(start, 0);
128126

129127
cudaMemcpyAsync(da, ha, N * sizeof(int), cudaMemcpyHostToDevice);
130128
cudaMemcpyAsync(da, ha, N * sizeof(int), cudaMemcpyHostToDevice, 0);
131129
cudaMemcpyAsync(da, ha, N * sizeof(int), cudaMemcpyHostToDevice, stream);
132130

133-
// CHECK: dpct::get_current_device().queues_wait_and_throw();
134-
// CHECK-NEXT: *stop = q_ct1.single_task([=](){});
135-
// CHECK-NEXT: dpct::get_current_device().queues_wait_and_throw();
136-
// CHECK-NEXT: stop->wait_and_throw();
137-
// CHECK-NEXT: elapsedTime = (stop->get_profiling_info<sycl::info::event_profiling::command_end>() - start->get_profiling_info<sycl::info::event_profiling::command_start>()) / 1000000.0f;
131+
// CHECK: dpct::sync_barrier(stop, &q_ct1);
132+
// CHECK-NEXT: stop->wait_and_throw();
133+
// CHECK-NEXT: elapsedTime = (stop->get_profiling_info<sycl::info::event_profiling::command_end>() - start->get_profiling_info<sycl::info::event_profiling::command_start>()) / 1000000.0f;
138134
cudaEventRecord(stop, 0);
139135
cudaEventSynchronize(stop);
140136
cudaEventElapsedTime(&elapsedTime, start, stop);

clang/test/dpct/event_record_with_flags_no_qb.cu

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,9 @@
1212
#include <cuda_runtime.h>
1313

1414
// CHECK: void cudaEventRecordWithFlags_1() {
15-
// CHECK-NEXT: dpct::event_ptr start;
16-
// CHECK-NEXT: dpct::queue_ptr stream;
17-
// CHECK-NEXT: *start = stream->single_task([=](){});
15+
// CHECK-NEXT: dpct::event_ptr start;
16+
// CHECK-NEXT: dpct::queue_ptr stream;
17+
// CHECK-NEXT: dpct::sync_barrier(start, stream);
1818
// CHECK-NEXT: }
1919
void cudaEventRecordWithFlags_1() {
2020
cudaEvent_t start;

clang/test/dpct/event_record_with_flags_qb.cu

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,9 @@
1111
#include <cuda_runtime.h>
1212

1313
// CHECK: void cudaEventRecordWithFlags_1() {
14-
// CHECK-NEXT: dpct::event_ptr start;
15-
// CHECK-NEXT: dpct::queue_ptr stream;
16-
// CHECK-NEXT: dpct::sync_barrier(start, stream);
14+
// CHECK-NEXT: dpct::event_ptr start;
15+
// CHECK-NEXT: dpct::queue_ptr stream;
16+
// CHECK-NEXT: dpct::sync_barrier(start, stream);
1717
// CHECK-NEXT: }
1818
void cudaEventRecordWithFlags_1() {
1919
cudaEvent_t start;

clang/test/dpct/tm-complex-profiling.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,7 @@ int foo_test_2()
257257
foo_kernel_3<<<grid, block, 0, streams[i]>>>();
258258
foo_kernel_4<<<grid, block, 0, streams[i]>>>();
259259

260-
// CHECK: CHECK(DPCT_CHECK_ERROR(*kernelEvent[i] = streams[i]->ext_oneapi_submit_barrier()));
260+
// CHECK: CHECK(DPCT_CHECK_ERROR(dpct::sync_barrier(kernelEvent[i], streams[i])));
261261
// CHECK-NEXT: streams[n_streams - 1]->ext_oneapi_submit_barrier({*kernelEvent[i]});
262262
CHECK(cudaEventRecord(kernelEvent[i], streams[i]));
263263
cudaStreamWaitEvent(streams[n_streams - 1], kernelEvent[i], 0);

0 commit comments

Comments
 (0)