@@ -415,13 +415,17 @@ inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], int lda,
415
415
ptrs, events);
416
416
mem_free_thread.detach ();
417
417
#else
418
- std::int64_t m_int64 = n;
419
- std::int64_t n_int64 = n;
420
- std::int64_t lda_int64 = lda;
421
- std::int64_t group_sizes = batch_size;
418
+ std::int64_t *m_int64 = new std::int64_t ;
419
+ std::int64_t *n_int64 = new std::int64_t ;
420
+ std::int64_t *lda_int64 = new std::int64_t ;
421
+ std::int64_t *group_sizes = new std::int64_t ;
422
+ *m_int64 = n;
423
+ *n_int64 = n;
424
+ *lda_int64 = lda;
425
+ *group_sizes = batch_size;
422
426
std::int64_t scratchpad_size =
423
427
oneapi::mkl::lapack::getrf_batch_scratchpad_size<Ty>(
424
- exec_queue, & m_int64, & n_int64, & lda_int64, 1 , & group_sizes);
428
+ exec_queue, m_int64, n_int64, lda_int64, 1 , group_sizes);
425
429
426
430
Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
427
431
std::int64_t *ipiv_int64 =
@@ -433,9 +437,9 @@ inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], int lda,
433
437
for (std::int64_t i = 0 ; i < batch_size; ++i)
434
438
ipiv_int64_ptr[i] = ipiv_int64 + n * i;
435
439
436
- oneapi::mkl::lapack::getrf_batch (
437
- exec_queue, &m_int64, &n_int64, (Ty **)a_shared, & lda_int64,
438
- ipiv_int64_ptr, 1 , & group_sizes, scratchpad, scratchpad_size);
440
+ oneapi::mkl::lapack::getrf_batch (exec_queue, m_int64, n_int64,
441
+ (Ty **)a_shared, lda_int64, ipiv_int64_ptr ,
442
+ 1 , group_sizes, scratchpad, scratchpad_size);
439
443
440
444
sycl::event e = exec_queue.submit ([&](sycl::handler &cgh) {
441
445
cgh.parallel_for <
@@ -445,6 +449,15 @@ inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], int lda,
445
449
});
446
450
});
447
451
452
+ exec_queue.submit ([&](sycl::handler &cgh) {
453
+ cgh.depends_on (e);
454
+ cgh.host_task ([=] {
455
+ delete m_int64;
456
+ delete n_int64;
457
+ delete lda_int64;
458
+ delete group_sizes;
459
+ });
460
+ });
448
461
std::vector<void *> ptrs{scratchpad, ipiv_int64, ipiv_int64_ptr, a_shared};
449
462
::dpct::cs::enqueue_free (ptrs, {e}, exec_queue);
450
463
#endif
@@ -535,15 +548,22 @@ inline void getrs_batch_wrapper(sycl::queue &exec_queue,
535
548
ptrs, events);
536
549
mem_free_thread.detach ();
537
550
#else
538
- std::int64_t n_int64 = n;
539
- std::int64_t nrhs_int64 = nrhs;
540
- std::int64_t lda_int64 = lda;
541
- std::int64_t ldb_int64 = ldb;
542
- std::int64_t group_sizes = batch_size;
551
+ std::int64_t *n_int64 = new std::int64_t ;
552
+ std::int64_t *nrhs_int64 = new std::int64_t ;
553
+ std::int64_t *lda_int64 = new std::int64_t ;
554
+ std::int64_t *ldb_int64 = new std::int64_t ;
555
+ std::int64_t *group_sizes = new std::int64_t ;
556
+ oneapi::mkl::transpose *trans_array = new oneapi::mkl::transpose;
557
+ *n_int64 = n;
558
+ *nrhs_int64 = nrhs;
559
+ *lda_int64 = lda;
560
+ *ldb_int64 = ldb;
561
+ *group_sizes = batch_size;
562
+ *trans_array = trans;
543
563
std::int64_t scratchpad_size =
544
564
oneapi::mkl::lapack::getrs_batch_scratchpad_size<Ty>(
545
- exec_queue, &trans, & n_int64, & nrhs_int64, & lda_int64, & ldb_int64, 1 ,
546
- & group_sizes);
565
+ exec_queue, trans_array, n_int64, nrhs_int64, lda_int64, ldb_int64, 1 ,
566
+ group_sizes);
547
567
548
568
Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
549
569
std::int64_t *ipiv_int64 =
@@ -569,10 +589,21 @@ inline void getrs_batch_wrapper(sycl::queue &exec_queue,
569
589
ipiv_int64_ptr[i] = ipiv_int64 + n * i;
570
590
571
591
sycl::event e = oneapi::mkl::lapack::getrs_batch (
572
- exec_queue, &trans, & n_int64, & nrhs_int64, (Ty **)a_shared, & lda_int64,
573
- ipiv_int64_ptr, (Ty **)b_shared, & ldb_int64, 1 , & group_sizes, scratchpad,
592
+ exec_queue, trans_array, n_int64, nrhs_int64, (Ty **)a_shared, lda_int64,
593
+ ipiv_int64_ptr, (Ty **)b_shared, ldb_int64, 1 , group_sizes, scratchpad,
574
594
scratchpad_size);
575
595
596
+ exec_queue.submit ([&](sycl::handler &cgh) {
597
+ cgh.depends_on (e);
598
+ cgh.host_task ([=] {
599
+ delete n_int64;
600
+ delete nrhs_int64;
601
+ delete lda_int64;
602
+ delete ldb_int64;
603
+ delete group_sizes;
604
+ delete trans_array;
605
+ });
606
+ });
576
607
std::vector<void *> ptrs{scratchpad, ipiv_int64_ptr, ipiv_int64, a_shared,
577
608
b_shared};
578
609
::dpct::cs::enqueue_free (ptrs, {e}, exec_queue);
@@ -659,12 +690,15 @@ inline void getri_batch_wrapper(sycl::queue &exec_queue, int n, const T *a[],
659
690
ptrs, events);
660
691
mem_free_thread.detach ();
661
692
#else
662
- std::int64_t n_int64 = n;
663
- std::int64_t ldb_int64 = ldb;
664
- std::int64_t group_sizes = batch_size;
693
+ std::int64_t *n_int64 = new std::int64_t ;
694
+ std::int64_t *ldb_int64 = new std::int64_t ;
695
+ std::int64_t *group_sizes = new std::int64_t ;
696
+ *n_int64 = n;
697
+ *ldb_int64 = ldb;
698
+ *group_sizes = batch_size;
665
699
std::int64_t scratchpad_size =
666
700
oneapi::mkl::lapack::getri_batch_scratchpad_size<Ty>(
667
- exec_queue, & n_int64, & ldb_int64, 1 , & group_sizes);
701
+ exec_queue, n_int64, ldb_int64, 1 , group_sizes);
668
702
669
703
Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
670
704
std::int64_t *ipiv_int64 =
@@ -695,9 +729,17 @@ inline void getri_batch_wrapper(sycl::queue &exec_queue, int n, const T *a[],
695
729
}
696
730
697
731
sycl::event e = oneapi::mkl::lapack::getri_batch (
698
- exec_queue, & n_int64, (Ty **)b_shared, & ldb_int64, ipiv_int64_ptr, 1 ,
699
- & group_sizes, scratchpad, scratchpad_size);
732
+ exec_queue, n_int64, (Ty **)b_shared, ldb_int64, ipiv_int64_ptr, 1 ,
733
+ group_sizes, scratchpad, scratchpad_size);
700
734
735
+ exec_queue.submit ([&](sycl::handler &cgh) {
736
+ cgh.depends_on (e);
737
+ cgh.host_task ([=] {
738
+ delete n_int64;
739
+ delete ldb_int64;
740
+ delete group_sizes;
741
+ });
742
+ });
701
743
std::vector<void *> ptrs{scratchpad, ipiv_int64_ptr, ipiv_int64, a_shared,
702
744
b_shared};
703
745
::dpct::cs::enqueue_free (ptrs, {e}, exec_queue);
@@ -780,13 +822,17 @@ inline void geqrf_batch_wrapper(sycl::queue exec_queue, int m, int n, T *a[],
780
822
mem_free_thread_a.detach ();
781
823
mem_free_thread_tau.detach ();
782
824
#else
783
- std::int64_t m_int64 = n;
784
- std::int64_t n_int64 = n;
785
- std::int64_t lda_int64 = lda;
786
- std::int64_t group_sizes = batch_size;
825
+ std::int64_t *m_int64 = new std::int64_t ;
826
+ std::int64_t *n_int64 = new std::int64_t ;
827
+ std::int64_t *lda_int64 = new std::int64_t ;
828
+ std::int64_t *group_sizes = new std::int64_t ;
829
+ *m_int64 = n;
830
+ *n_int64 = n;
831
+ *lda_int64 = lda;
832
+ *group_sizes = batch_size;
787
833
std::int64_t scratchpad_size =
788
834
oneapi::mkl::lapack::geqrf_batch_scratchpad_size<Ty>(
789
- exec_queue, & m_int64, & n_int64, & lda_int64, 1 , & group_sizes);
835
+ exec_queue, m_int64, n_int64, lda_int64, 1 , group_sizes);
790
836
791
837
Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
792
838
T **a_shared = sycl::malloc_shared<T *>(batch_size, exec_queue);
@@ -795,9 +841,18 @@ inline void geqrf_batch_wrapper(sycl::queue exec_queue, int m, int n, T *a[],
795
841
exec_queue.memcpy (tau_shared, tau, batch_size * sizeof (T *)).wait ();
796
842
797
843
sycl::event e = oneapi::mkl::lapack::geqrf_batch (
798
- exec_queue, & m_int64, & n_int64, (Ty **)a_shared, & lda_int64,
799
- (Ty **)tau_shared, 1 , & group_sizes, scratchpad, scratchpad_size);
844
+ exec_queue, m_int64, n_int64, (Ty **)a_shared, lda_int64,
845
+ (Ty **)tau_shared, 1 , group_sizes, scratchpad, scratchpad_size);
800
846
847
+ exec_queue.submit ([&](sycl::handler &cgh) {
848
+ cgh.depends_on (e);
849
+ cgh.host_task ([=] {
850
+ delete m_int64;
851
+ delete n_int64;
852
+ delete lda_int64;
853
+ delete group_sizes;
854
+ });
855
+ });
801
856
std::vector<void *> ptrs{scratchpad, a_shared, tau_shared};
802
857
::dpct::cs::enqueue_free (ptrs, {e}, exec_queue);
803
858
#endif
0 commit comments