Skip to content

Commit 8513308

Browse files
authored
[SYCLomatic] Fix bugs in some blas helper functions (#2796)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent 6d208a4 commit 8513308

File tree

1 file changed

+85
-30
lines changed

1 file changed

+85
-30
lines changed

clang/runtime/dpct-rt/include/dpct/blas_utils.hpp

Lines changed: 85 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -415,13 +415,17 @@ inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], int lda,
415415
ptrs, events);
416416
mem_free_thread.detach();
417417
#else
418-
std::int64_t m_int64 = n;
419-
std::int64_t n_int64 = n;
420-
std::int64_t lda_int64 = lda;
421-
std::int64_t group_sizes = batch_size;
418+
std::int64_t *m_int64 = new std::int64_t;
419+
std::int64_t *n_int64 = new std::int64_t;
420+
std::int64_t *lda_int64 = new std::int64_t;
421+
std::int64_t *group_sizes = new std::int64_t;
422+
*m_int64 = n;
423+
*n_int64 = n;
424+
*lda_int64 = lda;
425+
*group_sizes = batch_size;
422426
std::int64_t scratchpad_size =
423427
oneapi::mkl::lapack::getrf_batch_scratchpad_size<Ty>(
424-
exec_queue, &m_int64, &n_int64, &lda_int64, 1, &group_sizes);
428+
exec_queue, m_int64, n_int64, lda_int64, 1, group_sizes);
425429

426430
Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
427431
std::int64_t *ipiv_int64 =
@@ -433,9 +437,9 @@ inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], int lda,
433437
for (std::int64_t i = 0; i < batch_size; ++i)
434438
ipiv_int64_ptr[i] = ipiv_int64 + n * i;
435439

436-
oneapi::mkl::lapack::getrf_batch(
437-
exec_queue, &m_int64, &n_int64, (Ty **)a_shared, &lda_int64,
438-
ipiv_int64_ptr, 1, &group_sizes, scratchpad, scratchpad_size);
440+
oneapi::mkl::lapack::getrf_batch(exec_queue, m_int64, n_int64,
441+
(Ty **)a_shared, lda_int64, ipiv_int64_ptr,
442+
1, group_sizes, scratchpad, scratchpad_size);
439443

440444
sycl::event e = exec_queue.submit([&](sycl::handler &cgh) {
441445
cgh.parallel_for<
@@ -445,6 +449,15 @@ inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], int lda,
445449
});
446450
});
447451

452+
exec_queue.submit([&](sycl::handler &cgh) {
453+
cgh.depends_on(e);
454+
cgh.host_task([=] {
455+
delete m_int64;
456+
delete n_int64;
457+
delete lda_int64;
458+
delete group_sizes;
459+
});
460+
});
448461
std::vector<void *> ptrs{scratchpad, ipiv_int64, ipiv_int64_ptr, a_shared};
449462
::dpct::cs::enqueue_free(ptrs, {e}, exec_queue);
450463
#endif
@@ -535,15 +548,22 @@ inline void getrs_batch_wrapper(sycl::queue &exec_queue,
535548
ptrs, events);
536549
mem_free_thread.detach();
537550
#else
538-
std::int64_t n_int64 = n;
539-
std::int64_t nrhs_int64 = nrhs;
540-
std::int64_t lda_int64 = lda;
541-
std::int64_t ldb_int64 = ldb;
542-
std::int64_t group_sizes = batch_size;
551+
std::int64_t *n_int64 = new std::int64_t;
552+
std::int64_t *nrhs_int64 = new std::int64_t;
553+
std::int64_t *lda_int64 = new std::int64_t;
554+
std::int64_t *ldb_int64 = new std::int64_t;
555+
std::int64_t *group_sizes = new std::int64_t;
556+
oneapi::mkl::transpose *trans_array = new oneapi::mkl::transpose;
557+
*n_int64 = n;
558+
*nrhs_int64 = nrhs;
559+
*lda_int64 = lda;
560+
*ldb_int64 = ldb;
561+
*group_sizes = batch_size;
562+
*trans_array = trans;
543563
std::int64_t scratchpad_size =
544564
oneapi::mkl::lapack::getrs_batch_scratchpad_size<Ty>(
545-
exec_queue, &trans, &n_int64, &nrhs_int64, &lda_int64, &ldb_int64, 1,
546-
&group_sizes);
565+
exec_queue, trans_array, n_int64, nrhs_int64, lda_int64, ldb_int64, 1,
566+
group_sizes);
547567

548568
Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
549569
std::int64_t *ipiv_int64 =
@@ -569,10 +589,21 @@ inline void getrs_batch_wrapper(sycl::queue &exec_queue,
569589
ipiv_int64_ptr[i] = ipiv_int64 + n * i;
570590

571591
sycl::event e = oneapi::mkl::lapack::getrs_batch(
572-
exec_queue, &trans, &n_int64, &nrhs_int64, (Ty **)a_shared, &lda_int64,
573-
ipiv_int64_ptr, (Ty **)b_shared, &ldb_int64, 1, &group_sizes, scratchpad,
592+
exec_queue, trans_array, n_int64, nrhs_int64, (Ty **)a_shared, lda_int64,
593+
ipiv_int64_ptr, (Ty **)b_shared, ldb_int64, 1, group_sizes, scratchpad,
574594
scratchpad_size);
575595

596+
exec_queue.submit([&](sycl::handler &cgh) {
597+
cgh.depends_on(e);
598+
cgh.host_task([=] {
599+
delete n_int64;
600+
delete nrhs_int64;
601+
delete lda_int64;
602+
delete ldb_int64;
603+
delete group_sizes;
604+
delete trans_array;
605+
});
606+
});
576607
std::vector<void *> ptrs{scratchpad, ipiv_int64_ptr, ipiv_int64, a_shared,
577608
b_shared};
578609
::dpct::cs::enqueue_free(ptrs, {e}, exec_queue);
@@ -659,12 +690,15 @@ inline void getri_batch_wrapper(sycl::queue &exec_queue, int n, const T *a[],
659690
ptrs, events);
660691
mem_free_thread.detach();
661692
#else
662-
std::int64_t n_int64 = n;
663-
std::int64_t ldb_int64 = ldb;
664-
std::int64_t group_sizes = batch_size;
693+
std::int64_t *n_int64 = new std::int64_t;
694+
std::int64_t *ldb_int64 = new std::int64_t;
695+
std::int64_t *group_sizes = new std::int64_t;
696+
*n_int64 = n;
697+
*ldb_int64 = ldb;
698+
*group_sizes = batch_size;
665699
std::int64_t scratchpad_size =
666700
oneapi::mkl::lapack::getri_batch_scratchpad_size<Ty>(
667-
exec_queue, &n_int64, &ldb_int64, 1, &group_sizes);
701+
exec_queue, n_int64, ldb_int64, 1, group_sizes);
668702

669703
Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
670704
std::int64_t *ipiv_int64 =
@@ -695,9 +729,17 @@ inline void getri_batch_wrapper(sycl::queue &exec_queue, int n, const T *a[],
695729
}
696730

697731
sycl::event e = oneapi::mkl::lapack::getri_batch(
698-
exec_queue, &n_int64, (Ty **)b_shared, &ldb_int64, ipiv_int64_ptr, 1,
699-
&group_sizes, scratchpad, scratchpad_size);
732+
exec_queue, n_int64, (Ty **)b_shared, ldb_int64, ipiv_int64_ptr, 1,
733+
group_sizes, scratchpad, scratchpad_size);
700734

735+
exec_queue.submit([&](sycl::handler &cgh) {
736+
cgh.depends_on(e);
737+
cgh.host_task([=] {
738+
delete n_int64;
739+
delete ldb_int64;
740+
delete group_sizes;
741+
});
742+
});
701743
std::vector<void *> ptrs{scratchpad, ipiv_int64_ptr, ipiv_int64, a_shared,
702744
b_shared};
703745
::dpct::cs::enqueue_free(ptrs, {e}, exec_queue);
@@ -780,13 +822,17 @@ inline void geqrf_batch_wrapper(sycl::queue exec_queue, int m, int n, T *a[],
780822
mem_free_thread_a.detach();
781823
mem_free_thread_tau.detach();
782824
#else
783-
std::int64_t m_int64 = n;
784-
std::int64_t n_int64 = n;
785-
std::int64_t lda_int64 = lda;
786-
std::int64_t group_sizes = batch_size;
825+
std::int64_t *m_int64 = new std::int64_t;
826+
std::int64_t *n_int64 = new std::int64_t;
827+
std::int64_t *lda_int64 = new std::int64_t;
828+
std::int64_t *group_sizes = new std::int64_t;
829+
*m_int64 = n;
830+
*n_int64 = n;
831+
*lda_int64 = lda;
832+
*group_sizes = batch_size;
787833
std::int64_t scratchpad_size =
788834
oneapi::mkl::lapack::geqrf_batch_scratchpad_size<Ty>(
789-
exec_queue, &m_int64, &n_int64, &lda_int64, 1, &group_sizes);
835+
exec_queue, m_int64, n_int64, lda_int64, 1, group_sizes);
790836

791837
Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
792838
T **a_shared = sycl::malloc_shared<T *>(batch_size, exec_queue);
@@ -795,9 +841,18 @@ inline void geqrf_batch_wrapper(sycl::queue exec_queue, int m, int n, T *a[],
795841
exec_queue.memcpy(tau_shared, tau, batch_size * sizeof(T *)).wait();
796842

797843
sycl::event e = oneapi::mkl::lapack::geqrf_batch(
798-
exec_queue, &m_int64, &n_int64, (Ty **)a_shared, &lda_int64,
799-
(Ty **)tau_shared, 1, &group_sizes, scratchpad, scratchpad_size);
844+
exec_queue, m_int64, n_int64, (Ty **)a_shared, lda_int64,
845+
(Ty **)tau_shared, 1, group_sizes, scratchpad, scratchpad_size);
800846

847+
exec_queue.submit([&](sycl::handler &cgh) {
848+
cgh.depends_on(e);
849+
cgh.host_task([=] {
850+
delete m_int64;
851+
delete n_int64;
852+
delete lda_int64;
853+
delete group_sizes;
854+
});
855+
});
801856
std::vector<void *> ptrs{scratchpad, a_shared, tau_shared};
802857
::dpct::cs::enqueue_free(ptrs, {e}, exec_queue);
803858
#endif

0 commit comments

Comments
 (0)