Skip to content

Commit d5dc0f1

Browse files
authored
[SYCLomatic] Fix the index base of the result of legacy cublasI<t>amin and cublasI<t>amax (#2815)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent c03712d commit d5dc0f1

File tree

5 files changed

+37
-23
lines changed

5 files changed

+37
-23
lines changed

clang/lib/DPCT/RulesMathLib/BLASAPIMigration.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -723,13 +723,20 @@ void BLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
723723
ReplInfo.BufferTypeInfo[ReplInfo.BufferTypeInfo.size() - 1];
724724
std::string ReturnValueParamsStr;
725725
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_Restricted) {
726+
CallExprReplStr = CallExprReplStr + ", " + ResultTempPtr;
727+
if (FuncName == "cublasIsamax" || FuncName == "cublasIdamax" ||
728+
FuncName == "cublasIcamax" || FuncName == "cublasIzamax" ||
729+
FuncName == "cublasIsamin" || FuncName == "cublasIdamin" ||
730+
FuncName == "cublasIcamin" || FuncName == "cublasIzamin") {
731+
CallExprReplStr = CallExprReplStr + ", oneapi::mkl::index_base::one";
732+
}
726733
requestFeature(HelperFeatureEnum::device_ext);
727734
auto DefaultQueue = DpctGlobalInfo::getDefaultQueue(CE);
728735
PrefixInsertStr = PrefixInsertStr + ResultType + "* " + ResultTempPtr +
729736
" = " + MapNames::getClNamespace() +
730-
"malloc_shared<" + ResultType + ">(1, " + DefaultQueue + ");" +
731-
getNL() + IndentStr + CallExprReplStr + ", " +
732-
ResultTempPtr + ").wait();" + getNL() + IndentStr;
737+
"malloc_shared<" + ResultType + ">(1, " +
738+
DefaultQueue + ");" + getNL() + IndentStr +
739+
CallExprReplStr + ").wait();" + getNL() + IndentStr;
733740

734741
ReturnValueParamsStr =
735742
"(" + ResultTempPtr + "->real(), " + ResultTempPtr + "->imag())";
@@ -748,11 +755,18 @@ void BLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
748755
ResultTempPtr + ", " + DefaultQueue + ");";
749756
}
750757
} else {
758+
CallExprReplStr = CallExprReplStr + ", " + ResultTempBuf;
759+
if (FuncName == "cublasIsamax" || FuncName == "cublasIdamax" ||
760+
FuncName == "cublasIcamax" || FuncName == "cublasIzamax" ||
761+
FuncName == "cublasIsamin" || FuncName == "cublasIdamin" ||
762+
FuncName == "cublasIcamin" || FuncName == "cublasIzamin") {
763+
CallExprReplStr = CallExprReplStr + ", oneapi::mkl::index_base::one";
764+
}
751765
PrefixInsertStr = PrefixInsertStr + MapNames::getClNamespace() +
752766
"buffer<" + ResultType + "> " + ResultTempBuf + "(" +
753767
MapNames::getClNamespace() + "range<1>(1));" +
754-
getNL() + IndentStr + CallExprReplStr + ", " +
755-
ResultTempBuf + ");" + getNL() + IndentStr;
768+
getNL() + IndentStr + CallExprReplStr + ");" +
769+
getNL() + IndentStr;
756770
ReturnValueParamsStr =
757771
"(" + ResultTempBuf + ".get_host_access(" +
758772
MapNames::getClNamespace() + "read_only)[0].real(), " +

clang/test/dpct/cublas-usm-legacy.cu

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,30 +66,30 @@ int main() {
6666

6767
// CHECK: int res;
6868
// CHECK-NEXT: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared<int64_t>(1, q_ct1);
69-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S, incx, res_temp_ptr_ct{{[0-9]+}}).wait();
69+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait();
7070
// CHECK-NEXT: res = *res_temp_ptr_ct{{[0-9]+}};
7171
// CHECK-NEXT:sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1);
7272
int res = cublasIsamax(n, x_S, incx);
7373
// CHECK: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared<int64_t>(1, q_ct1);
74-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_D, incx, res_temp_ptr_ct{{[0-9]+}}).wait();
74+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_D, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait();
7575
// CHECK-NEXT: res = *res_temp_ptr_ct{{[0-9]+}};
7676
// CHECK-NEXT:sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1);
7777
res = cublasIdamax(n, x_D, incx);
7878
// CHECK: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared<int64_t>(1, q_ct1);
79-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex<float>*)x_C, incx, res_temp_ptr_ct{{[0-9]+}}).wait();
79+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex<float>*)x_C, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait();
8080
// CHECK-NEXT: res = *res_temp_ptr_ct{{[0-9]+}};
8181
// CHECK-NEXT:sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1);
8282
res = cublasIcamax(n, x_C, incx);
8383
// CHECK: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared<int64_t>(1, q_ct1);
84-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex<double>*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}).wait();
84+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex<double>*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait();
8585
// CHECK-NEXT: res = *res_temp_ptr_ct{{[0-9]+}};
8686
// CHECK-NEXT:sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1);
8787
res = cublasIzamax(n, x_Z, incx);
8888

8989
// The original API returns the result value rather than a status, so keep using a lambda here.
9090
// CHECK: if([&](){
9191
// CHECK-NEXT: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared<int64_t>(1, q_ct1);
92-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex<double>*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}).wait();
92+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex<double>*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait();
9393
// CHECK-NEXT: int64_t res_temp_val_ct{{[0-9]+}} = *res_temp_ptr_ct{{[0-9]+}};
9494
// CHECK-NEXT: sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1);
9595
// CHECK-NEXT: return res_temp_val_ct{{[0-9]+}};
@@ -98,7 +98,7 @@ int main() {
9898

9999
// CHECK: if(0!=[&](){
100100
// CHECK-NEXT: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared<int64_t>(1, q_ct1);
101-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex<double>*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}).wait();
101+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex<double>*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait();
102102
// CHECK-NEXT: int64_t res_temp_val_ct{{[0-9]+}} = *res_temp_ptr_ct{{[0-9]+}};
103103
// CHECK-NEXT: sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1);
104104
// CHECK-NEXT: return res_temp_val_ct{{[0-9]+}};
@@ -233,7 +233,7 @@ int main() {
233233
//CHECK:int foo(){
234234
//CHECK-NEXT: return [&](){
235235
//CHECK-NEXT: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared<int64_t>(1, dpct::get_in_order_queue());
236-
//CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex<double>*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}).wait();
236+
//CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex<double>*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait();
237237
//CHECK-NEXT: int64_t res_temp_val_ct{{[0-9]+}} = *res_temp_ptr_ct{{[0-9]+}};
238238
//CHECK-NEXT: sycl::free(res_temp_ptr_ct{{[0-9]+}}, dpct::get_in_order_queue());
239239
//CHECK-NEXT: return res_temp_val_ct{{[0-9]+}};

clang/test/dpct/cublasLegacyCZ.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,15 @@ int main() {
6060
// CHECK-NEXT: {
6161
// CHECK-NEXT: auto x_C_buf_ct{{[0-9]+}} = dpct::get_buffer<std::complex<float>>(x_C);
6262
// CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
63-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_C_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}});
63+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_C_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
6464
// CHECK-NEXT: res = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
6565
// CHECK-NEXT: }
6666
int res = cublasIcamax(n, x_C, incx);
6767

6868
// CHECK: {
6969
// CHECK-NEXT: auto x_Z_buf_ct{{[0-9]+}} = dpct::get_buffer<std::complex<double>>(x_Z);
7070
// CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
71-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_Z_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}});
71+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_Z_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
7272
// CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
7373
// CHECK-NEXT: }
7474
*result = cublasIzamax(n, x_Z, incx);
@@ -77,15 +77,15 @@ int main() {
7777
// CHECK: {
7878
// CHECK-NEXT: auto x_C_buf_ct{{[0-9]+}} = dpct::get_buffer<std::complex<float>>(x_C);
7979
// CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
80-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_C_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}});
80+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_C_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
8181
// CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
8282
// CHECK-NEXT: }
8383
*result = cublasIcamin(n, x_C, incx);
8484

8585
// CHECK: {
8686
// CHECK-NEXT: auto x_Z_buf_ct{{[0-9]+}} = dpct::get_buffer<std::complex<double>>(x_Z);
8787
// CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
88-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_Z_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}});
88+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_Z_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
8989
// CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
9090
// CHECK-NEXT: }
9191
*result = cublasIzamin(n, x_Z, incx);

clang/test/dpct/cublasLegacyLv123.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ int main() {
5050
// CHECK-NEXT: {
5151
// CHECK-NEXT: auto x_S_buf_ct{{[0-9]+}} = dpct::get_buffer<float>(x_S);
5252
// CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
53-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}});
53+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
5454
// CHECK-NEXT: res = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
5555
// CHECK-NEXT: }
5656
int res = cublasIsamax(n, x_S, incx);
5757

5858
// CHECK: {
5959
// CHECK-NEXT: auto x_D_buf_ct{{[0-9]+}} = dpct::get_buffer<double>(x_D);
6060
// CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
61-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_D_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}});
61+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_D_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
6262
// CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
6363
// CHECK-NEXT: }
6464
*result = cublasIdamax(n, x_D, incx);
@@ -67,15 +67,15 @@ int main() {
6767
// CHECK: {
6868
// CHECK-NEXT: auto x_S_buf_ct{{[0-9]+}} = dpct::get_buffer<float>(x_S);
6969
// CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
70-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}});
70+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
7171
// CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
7272
// CHECK-NEXT: }
7373
*result = cublasIsamin(n, x_S, incx);
7474

7575
// CHECK: {
7676
// CHECK-NEXT: auto x_D_buf_ct{{[0-9]+}} = dpct::get_buffer<double>(x_D);
7777
// CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
78-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_D_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}});
78+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_D_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
7979
// CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
8080
// CHECK-NEXT: }
8181
*result = cublasIdamin(n, x_D, incx);
@@ -627,7 +627,7 @@ int main() {
627627
// CHECK: for(int i = [&](){
628628
// CHECK-NEXT: auto x_S_buf_ct{{[0-9]+}} = dpct::get_buffer<float>(x_S);
629629
// CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
630-
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}});
630+
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
631631
// CHECK-NEXT: return res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
632632
// CHECK-NEXT: }();;){}
633633
for(int i = cublasIsamax(n, x_S, incx);;){}
@@ -640,7 +640,7 @@ int main() {
640640
//CHECK-NEXT: return [&](){
641641
//CHECK-NEXT: auto x_S_buf_ct{{[0-9]+}} = dpct::get_buffer<float>(x_S);
642642
//CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
643-
//CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}});
643+
//CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
644644
//CHECK-NEXT: return res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
645645
//CHECK-NEXT: }();
646646
//CHECK-NEXT:}

clang/test/dpct/error-handling.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ void foo12() {
705705
// CHECK-NEXT: sycl::buffer<int64_t> res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1));
706706
// CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(
707707
// CHECK-NEXT: dpct::blas::descriptor::get_saved_queue(), 10, ct_0_buf_ct{{[0-9]+}}, 0,
708-
// CHECK-NEXT: res_temp_buf_ct{{[0-9]+}});
708+
// CHECK-NEXT: res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one);
709709
// CHECK-NEXT: res = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0];
710710
// CHECK-NEXT: }
711711
// CHECK-NEXT: }

0 commit comments

Comments
 (0)