Skip to content

Commit dead51d

Browse files
[SYCLomatic] Support the migration of cusparse<T>csrgemm2 (c=a*b+d) related API with helper functions (#2643)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com> Co-authored-by: Wang, Zhiming <zhiming.wang@intel.com>
1 parent 34e57a6 commit dead51d

File tree

7 files changed

+656
-30
lines changed

7 files changed

+656
-30
lines changed

clang/lib/DPCT/RuleInfra/APINamesTemplateType.inc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,10 @@ TYPE_REWRITE_ENTRY("cusparseSolvePolicy_t", TYPE_FACTORY(STR("int")))
441441
TYPE_REWRITE_ENTRY("cusparseAction_t",
442442
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
443443
"sparse::conversion_scope")))
444+
TYPE_REWRITE_ENTRY("csrgemm2Info_t",
445+
TYPE_FACTORY(STR("std::shared_ptr<" +
446+
MapNames::getLibraryHelperNamespace() +
447+
"sparse::csrgemm2_info>")))
444448

445449
TYPE_REWRITE_ENTRY(
446450
"cooperative_groups::__v1::coalesced_group",

clang/lib/DPCT/RulesMathLib/APINamesCUSPARSE.inc

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,3 +556,79 @@ ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
556556
ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(8), ARG(9),
557557
ARG(10), ARG(11), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18),
558558
ARG(19))))
559+
560+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
561+
"cusparseScsrgemm2_bufferSizeExt",
562+
CALL(MapNames::getLibraryHelperNamespace() +
563+
"sparse::csrgemm2_get_buffer_size<float>",
564+
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8),
565+
ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16),
566+
ARG(17), ARG(18), ARG(19))))
567+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
568+
"cusparseDcsrgemm2_bufferSizeExt",
569+
CALL(MapNames::getLibraryHelperNamespace() +
570+
"sparse::csrgemm2_get_buffer_size<double>",
571+
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8),
572+
ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16),
573+
ARG(17), ARG(18), ARG(19))))
574+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
575+
"cusparseCcsrgemm2_bufferSizeExt",
576+
CALL(MapNames::getLibraryHelperNamespace() +
577+
"sparse::csrgemm2_get_buffer_size<" + MapNames::getClNamespace() +
578+
"float2>",
579+
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8),
580+
ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16),
581+
ARG(17), ARG(18), ARG(19))))
582+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
583+
"cusparseZcsrgemm2_bufferSizeExt",
584+
CALL(MapNames::getLibraryHelperNamespace() +
585+
"sparse::csrgemm2_get_buffer_size<" + MapNames::getClNamespace() +
586+
"double2>",
587+
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8),
588+
ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16),
589+
ARG(17), ARG(18), ARG(19))))
590+
591+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
592+
"cusparseXcsrgemm2Nnz",
593+
CALL(MapNames::getLibraryHelperNamespace() + "sparse::csrgemm2_nnz", ARG(0),
594+
ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9),
595+
ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17),
596+
ARG(18), ARG(19), ARG(20))))
597+
598+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
599+
"cusparseScsrgemm2",
600+
CALL(MapNames::getLibraryHelperNamespace() + "sparse::csrgemm2<float>",
601+
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8),
602+
ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16),
603+
ARG(17), ARG(18), ARG(19), ARG(20), ARG(21), ARG(22), ARG(23), ARG(24),
604+
ARG(25), ARG(26))))
605+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
606+
"cusparseDcsrgemm2",
607+
CALL(MapNames::getLibraryHelperNamespace() + "sparse::csrgemm2<double>",
608+
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8),
609+
ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16),
610+
ARG(17), ARG(18), ARG(19), ARG(20), ARG(21), ARG(22), ARG(23), ARG(24),
611+
ARG(25), ARG(26))))
612+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
613+
"cusparseCcsrgemm2",
614+
CALL(MapNames::getLibraryHelperNamespace() + "sparse::csrgemm2<" +
615+
MapNames::getClNamespace() + "float2>",
616+
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8),
617+
ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16),
618+
ARG(17), ARG(18), ARG(19), ARG(20), ARG(21), ARG(22), ARG(23), ARG(24),
619+
ARG(25), ARG(26))))
620+
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
621+
"cusparseZcsrgemm2",
622+
CALL(MapNames::getLibraryHelperNamespace() + "sparse::csrgemm2<" +
623+
MapNames::getClNamespace() + "double2>",
624+
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8),
625+
ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16),
626+
ARG(17), ARG(18), ARG(19), ARG(20), ARG(21), ARG(22), ARG(23), ARG(24),
627+
ARG(25), ARG(26))))
628+
629+
ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY(
630+
"cusparseCreateCsrgemm2Info", DEREF(0),
631+
CALL("std::make_shared<" + MapNames::getLibraryHelperNamespace() +
632+
"sparse::csrgemm2_info>")))
633+
ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cusparseDestroyCsrgemm2Info",
634+
ARG(0), false, "reset"))

clang/lib/DPCT/RulesMathLib/SpBLASAPIMigration.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using namespace clang::ast_matchers;
2020
void SpBLASTypeLocRule::registerMatcher(ast_matchers::MatchFinder &MF) {
2121
auto TargetTypeName = [&]() {
2222
return hasAnyName("csrsv2Info_t", "cusparseSolvePolicy_t",
23-
"cusparseAction_t");
23+
"cusparseAction_t", "csrgemm2Info_t");
2424
};
2525

2626
MF.addMatcher(
@@ -56,7 +56,8 @@ void SPBLASFunctionCallRule::registerMatcher(MatchFinder &MF) {
5656
"cusparseGetMatDiagType", "cusparseSetMatFillMode",
5757
"cusparseGetMatFillMode", "cusparseCreateSolveAnalysisInfo",
5858
"cusparseDestroySolveAnalysisInfo", "cusparseCreateCsrsv2Info",
59-
"cusparseDestroyCsrsv2Info",
59+
"cusparseDestroyCsrsv2Info", "cusparseCreateCsrgemm2Info",
60+
"cusparseDestroyCsrgemm2Info",
6061
/*level 2*/
6162
"cusparseScsrmv", "cusparseDcsrmv", "cusparseCcsrmv", "cusparseZcsrmv",
6263
"cusparseScsrmv_mp", "cusparseDcsrmv_mp", "cusparseCcsrmv_mp",
@@ -79,6 +80,10 @@ void SPBLASFunctionCallRule::registerMatcher(MatchFinder &MF) {
7980
"cusparseScsrgemm", "cusparseDcsrgemm", "cusparseCcsrgemm",
8081
"cusparseZcsrgemm", "cusparseXcsrgemmNnz", "cusparseScsrmm2",
8182
"cusparseDcsrmm2", "cusparseCcsrmm2", "cusparseZcsrmm2",
83+
"cusparseScsrgemm2_bufferSizeExt", "cusparseDcsrgemm2_bufferSizeExt",
84+
"cusparseCcsrgemm2_bufferSizeExt", "cusparseZcsrgemm2_bufferSizeExt",
85+
"cusparseXcsrgemm2Nnz", "cusparseScsrgemm2", "cusparseDcsrgemm2",
86+
"cusparseCcsrgemm2", "cusparseZcsrgemm2",
8287
/*Generic*/
8388
"cusparseCreateCsr", "cusparseDestroySpMat", "cusparseCsrGet",
8489
"cusparseSpMatGetFormat", "cusparseSpMatGetIndexBase",
@@ -177,7 +182,7 @@ void SPBLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
177182
.bind("CallExpr"));
178183
auto CEResults = match(Matcher, *CS1, DpctGlobalInfo::getContext());
179184
// Find the correct call
180-
const CallExpr* CorrectCall = nullptr;
185+
const CallExpr *CorrectCall = nullptr;
181186
for (auto &Result : CEResults) {
182187
const CallExpr *MatchedCE = Result.getNodeAs<CallExpr>("CallExpr");
183188
if (MatchedCE) {
@@ -257,7 +262,6 @@ void SPBLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
257262
}
258263
}
259264

260-
261265
// Rule for spBLAS enums.
262266
// Migrate spBLAS status values to corresponding int values
263267
// Other spBLAS named values are migrated to corresponding named values

clang/lib/DPCT/SrcAPI/APINames_cuSPARSE.inc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ ENTRY(cusparseCreateBsric02Info, cusparseCreateBsric02Info, false, NO_FLAG, P4,
7070
ENTRY(cusparseDestroyBsric02Info, cusparseDestroyBsric02Info, false, NO_FLAG, P4, "comment")
7171
ENTRY(cusparseCreateBsrilu02Info, cusparseCreateBsrilu02Info, false, NO_FLAG, P4, "comment")
7272
ENTRY(cusparseDestroyBsrilu02Info, cusparseDestroyBsrilu02Info, false, NO_FLAG, P4, "comment")
73-
ENTRY(cusparseCreateCsrgemm2Info, cusparseCreateCsrgemm2Info, false, NO_FLAG, P4, "comment")
74-
ENTRY(cusparseDestroyCsrgemm2Info, cusparseDestroyCsrgemm2Info, false, NO_FLAG, P4, "comment")
73+
ENTRY(cusparseCreateCsrgemm2Info, cusparseCreateCsrgemm2Info, true, NO_FLAG, P4, "comment")
74+
ENTRY(cusparseDestroyCsrgemm2Info, cusparseDestroyCsrgemm2Info, true, NO_FLAG, P4, "comment")
7575
ENTRY(cusparseCreatePruneInfo, cusparseCreatePruneInfo, false, NO_FLAG, P4, "comment")
7676
ENTRY(cusparseDestroyPruneInfo, cusparseDestroyPruneInfo, false, NO_FLAG, P4, "comment")
7777

@@ -249,15 +249,15 @@ ENTRY(cusparseScsrgemm, cusparseScsrgemm, true, NO_FLAG, P4, "comment")
249249
ENTRY(cusparseDcsrgemm, cusparseDcsrgemm, true, NO_FLAG, P4, "comment")
250250
ENTRY(cusparseCcsrgemm, cusparseCcsrgemm, true, NO_FLAG, P4, "comment")
251251
ENTRY(cusparseZcsrgemm, cusparseZcsrgemm, true, NO_FLAG, P4, "comment")
252-
ENTRY(cusparseScsrgemm2_bufferSizeExt, cusparseScsrgemm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
253-
ENTRY(cusparseDcsrgemm2_bufferSizeExt, cusparseDcsrgemm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
254-
ENTRY(cusparseCcsrgemm2_bufferSizeExt, cusparseCcsrgemm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
255-
ENTRY(cusparseZcsrgemm2_bufferSizeExt, cusparseZcsrgemm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
256-
ENTRY(cusparseXcsrgemm2Nnz, cusparseXcsrgemm2Nnz, false, NO_FLAG, P4, "comment")
257-
ENTRY(cusparseScsrgemm2, cusparseScsrgemm2, false, NO_FLAG, P4, "comment")
258-
ENTRY(cusparseDcsrgemm2, cusparseDcsrgemm2, false, NO_FLAG, P4, "comment")
259-
ENTRY(cusparseCcsrgemm2, cusparseCcsrgemm2, false, NO_FLAG, P4, "comment")
260-
ENTRY(cusparseZcsrgemm2, cusparseZcsrgemm2, false, NO_FLAG, P4, "comment")
252+
ENTRY(cusparseScsrgemm2_bufferSizeExt, cusparseScsrgemm2_bufferSizeExt, true, NO_FLAG, P4, "comment")
253+
ENTRY(cusparseDcsrgemm2_bufferSizeExt, cusparseDcsrgemm2_bufferSizeExt, true, NO_FLAG, P4, "comment")
254+
ENTRY(cusparseCcsrgemm2_bufferSizeExt, cusparseCcsrgemm2_bufferSizeExt, true, NO_FLAG, P4, "comment")
255+
ENTRY(cusparseZcsrgemm2_bufferSizeExt, cusparseZcsrgemm2_bufferSizeExt, true, NO_FLAG, P4, "comment")
256+
ENTRY(cusparseXcsrgemm2Nnz, cusparseXcsrgemm2Nnz, true, NO_FLAG, P4, "comment")
257+
ENTRY(cusparseScsrgemm2, cusparseScsrgemm2, true, NO_FLAG, P4, "comment")
258+
ENTRY(cusparseDcsrgemm2, cusparseDcsrgemm2, true, NO_FLAG, P4, "comment")
259+
ENTRY(cusparseCcsrgemm2, cusparseCcsrgemm2, true, NO_FLAG, P4, "comment")
260+
ENTRY(cusparseZcsrgemm2, cusparseZcsrgemm2, true, NO_FLAG, P4, "comment")
261261

262262
// preconditioner
263263
ENTRY(cusparseScsric0, cusparseScsric0, false, NO_FLAG, P4, "comment")

clang/runtime/dpct-rt/include/dpct/detail/sparse_utils_detail.hpp

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ struct csrgemm_args_info_hash {
3838
return std::hash<std::string>{}(ss.str());
3939
}
4040
};
41+
4142
#ifdef __INTEL_MKL__ // The oneMKL Interfaces Project does not support this.
4243
template <typename handle_t> class handle_manager {
4344
public:
@@ -47,18 +48,24 @@ template <typename handle_t> class handle_manager {
4748
~handle_manager() {
4849
if (!_q || !_h)
4950
return;
51+
release();
52+
}
53+
void init(sycl::queue *q) {
54+
_q = q;
55+
_h = new handle_t;
56+
_init_func(*_q, _h);
57+
}
58+
sycl::event release() {
59+
if (!_q || !_h)
60+
return sycl::event();
5061
sycl::event e = _rel_func(*_q, _h, _deps);
51-
_q->submit([&](sycl::handler &cgh) {
62+
sycl::event ret = _q->submit([&](sycl::handler &cgh) {
5263
cgh.depends_on(e);
5364
cgh.host_task([_hh = _h] { delete _hh; });
5465
});
5566
_h = nullptr;
5667
_q = nullptr;
57-
}
58-
void init(sycl::queue *q) {
59-
_q = q;
60-
_h = new handle_t;
61-
_init_func(_h);
68+
return ret;
6269
}
6370
handle_t &get_handle() { return *_h; }
6471
void add_dependency(sycl::event e) { _deps.push_back(e); }
@@ -68,7 +75,7 @@ template <typename handle_t> class handle_manager {
6875
sycl::queue *_q = nullptr;
6976

7077
private:
71-
using init_func_t = std::function<void(handle_t *)>;
78+
using init_func_t = std::function<void(sycl::queue &, handle_t *)>;
7279
using rel_func_t = std::function<sycl::event(
7380
sycl::queue &, handle_t *, const std::vector<sycl::event> &dependencies)>;
7481
handle_t *_h = nullptr;
@@ -79,16 +86,33 @@ template <typename handle_t> class handle_manager {
7986
template <>
8087
inline handle_manager<oneapi::mkl::sparse::matrix_handle_t>::init_func_t
8188
handle_manager<oneapi::mkl::sparse::matrix_handle_t>::_init_func =
82-
oneapi::mkl::sparse::init_matrix_handle;
89+
[](sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p_desc) {
90+
oneapi::mkl::sparse::init_matrix_handle(p_desc);
91+
};
8392
template <>
8493
inline handle_manager<oneapi::mkl::sparse::matrix_handle_t>::rel_func_t
8594
handle_manager<oneapi::mkl::sparse::matrix_handle_t>::_rel_func =
8695
oneapi::mkl::sparse::release_matrix_handle;
8796

97+
template <>
98+
inline handle_manager<oneapi::mkl::sparse::omatadd_descr_t>::init_func_t
99+
handle_manager<oneapi::mkl::sparse::omatadd_descr_t>::_init_func =
100+
oneapi::mkl::sparse::init_omatadd_descr;
101+
template <>
102+
inline handle_manager<oneapi::mkl::sparse::omatadd_descr_t>::rel_func_t
103+
handle_manager<oneapi::mkl::sparse::omatadd_descr_t>::_rel_func =
104+
[](sycl::queue &queue, oneapi::mkl::sparse::omatadd_descr_t *p_desc,
105+
const std::vector<sycl::event> &dependencies) -> sycl::event {
106+
return oneapi::mkl::sparse::release_omatadd_descr(queue, *p_desc,
107+
dependencies);
108+
};
109+
88110
template <>
89111
inline handle_manager<oneapi::mkl::sparse::matmat_descr_t>::init_func_t
90112
handle_manager<oneapi::mkl::sparse::matmat_descr_t>::_init_func =
91-
oneapi::mkl::sparse::init_matmat_descr;
113+
[](sycl::queue &queue, oneapi::mkl::sparse::matmat_descr_t *p_desc) {
114+
oneapi::mkl::sparse::init_matmat_descr(p_desc);
115+
};
92116
template <>
93117
inline handle_manager<oneapi::mkl::sparse::matmat_descr_t>::rel_func_t
94118
handle_manager<oneapi::mkl::sparse::matmat_descr_t>::_rel_func =

0 commit comments

Comments
 (0)