@@ -554,54 +554,55 @@ extern "C" int onemklZtrsmBatched(syclQueue_t device_queue, onemklSide left_righ
554
554
555
555
extern " C" int onemklHgemmBatchStrided (syclQueue_t device_queue, onemklTranspose transa,
556
556
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
557
- uint16_t alpha, const short *a, int64_t lda, int64_t stridea,
558
- const short *b, int64_t ldb, int64_t strideb, uint16_t beta,
557
+ uint16_t * alpha, const short *a, int64_t lda, int64_t stridea,
558
+ const short *b, int64_t ldb, int64_t strideb, uint16_t * beta,
559
559
short *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
560
560
auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val , convert (transa),
561
- convert (transb), m, n, k, sycl::bit_cast<sycl::half>(alpha),
561
+ convert (transb), m, n, k,
562
+ *reinterpret_cast <const sycl::half *>(alpha),
562
563
reinterpret_cast <const sycl::half *>(a), lda, stridea,
563
564
reinterpret_cast <const sycl::half *>(b), ldb, strideb,
564
- sycl::bit_cast< sycl::half>(beta),
565
+ * reinterpret_cast < const sycl::half * >(beta),
565
566
reinterpret_cast <sycl::half *>(c), ldc, stridec, batch_size, {});
566
567
__FORCE_MKL_FLUSH__ (status);
567
568
return 0 ;
568
569
}
569
570
570
571
extern " C" int onemklSgemmBatchStrided (syclQueue_t device_queue, onemklTranspose transa,
571
572
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
572
- float alpha, const float *a, int64_t lda, int64_t stridea,
573
- const float *b, int64_t ldb, int64_t strideb, float beta,
573
+ float * alpha, const float *a, int64_t lda, int64_t stridea,
574
+ const float *b, int64_t ldb, int64_t strideb, float * beta,
574
575
float *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
575
576
auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val , convert (transa),
576
- convert (transb), m, n, k, alpha, a, lda, stridea,
577
- b, ldb, strideb, beta, c, ldc, stridec, batch_size, {});
577
+ convert (transb), m, n, k, * alpha, a, lda, stridea,
578
+ b, ldb, strideb, * beta, c, ldc, stridec, batch_size, {});
578
579
__FORCE_MKL_FLUSH__ (status);
579
580
return 0 ;
580
581
}
581
582
582
583
extern " C" int onemklDgemmBatchStrided (syclQueue_t device_queue, onemklTranspose transa,
583
584
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
584
- double alpha, const double *a, int64_t lda, int64_t stridea,
585
- const double *b, int64_t ldb, int64_t strideb, double beta,
585
+ double * alpha, const double *a, int64_t lda, int64_t stridea,
586
+ const double *b, int64_t ldb, int64_t strideb, double * beta,
586
587
double *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
587
588
auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val , convert (transa),
588
- convert (transb), m, n, k, alpha, a, lda, stridea,
589
- b, ldb, strideb, beta, c, ldc, stridec, batch_size, {});
589
+ convert (transb), m, n, k, * alpha, a, lda, stridea,
590
+ b, ldb, strideb, * beta, c, ldc, stridec, batch_size, {});
590
591
__FORCE_MKL_FLUSH__ (status);
591
592
return 0 ;
592
593
}
593
594
594
595
extern " C" int onemklCgemmBatchStrided (syclQueue_t device_queue, onemklTranspose transa,
595
596
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
596
- float _Complex alpha, const float _Complex *a, int64_t lda, int64_t stridea,
597
- const float _Complex *b, int64_t ldb, int64_t strideb, float _Complex beta,
597
+ float _Complex * alpha, const float _Complex *a, int64_t lda, int64_t stridea,
598
+ const float _Complex *b, int64_t ldb, int64_t strideb, float _Complex * beta,
598
599
float _Complex *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
599
600
auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val , convert (transa),
600
- convert (transb), m, n, k, alpha,
601
+ convert (transb), m, n, k, * alpha,
601
602
reinterpret_cast <const std::complex<float > *>(a),
602
603
lda, stridea,
603
604
reinterpret_cast <const std::complex<float > *>(b),
604
- ldb, strideb, beta,
605
+ ldb, strideb, * beta,
605
606
reinterpret_cast <std::complex<float > *>(c),
606
607
ldc, stridec, batch_size, {});
607
608
__FORCE_MKL_FLUSH__ (status);
@@ -610,15 +611,15 @@ extern "C" int onemklCgemmBatchStrided(syclQueue_t device_queue, onemklTranspose
610
611
611
612
extern " C" int onemklZgemmBatchStrided (syclQueue_t device_queue, onemklTranspose transa,
612
613
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
613
- double _Complex alpha, const double _Complex *a, int64_t lda, int64_t stridea,
614
- const double _Complex *b, int64_t ldb, int64_t strideb, double _Complex beta,
614
+ double _Complex * alpha, const double _Complex *a, int64_t lda, int64_t stridea,
615
+ const double _Complex *b, int64_t ldb, int64_t strideb, double _Complex * beta,
615
616
double _Complex *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
616
617
auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val , convert (transa),
617
- convert (transb), m, n, k, alpha,
618
+ convert (transb), m, n, k, * alpha,
618
619
reinterpret_cast <const std::complex<double > *>(a),
619
620
lda, stridea,
620
621
reinterpret_cast <const std::complex<double > *>(b),
621
- ldb, strideb, beta,
622
+ ldb, strideb, * beta,
622
623
reinterpret_cast <std::complex<double > *>(c),
623
624
ldc, stridec, batch_size, {});
624
625
__FORCE_MKL_FLUSH__ (status);
@@ -627,14 +628,15 @@ extern "C" int onemklZgemmBatchStrided(syclQueue_t device_queue, onemklTranspose
627
628
628
629
extern " C" int onemklHgemm (syclQueue_t device_queue, onemklTranspose transA,
629
630
onemklTranspose transB, int64_t m, int64_t n,
630
- int64_t k, uint16_t alpha, const short *A, int64_t lda,
631
- const short *B, int64_t ldb, uint16_t beta, short *C,
631
+ int64_t k, uint16_t * alpha, const short *A, int64_t lda,
632
+ const short *B, int64_t ldb, uint16_t * beta, short *C,
632
633
int64_t ldc) {
633
634
auto status = oneapi::mkl::blas::column_major::gemm (device_queue->val , convert (transA),
634
- convert (transB), m, n, k, sycl::bit_cast<sycl::half>(alpha),
635
+ convert (transB), m, n, k,
636
+ *reinterpret_cast <const sycl::half *>(alpha),
635
637
reinterpret_cast <const sycl::half *>(A), lda,
636
638
reinterpret_cast <const sycl::half *>(B), ldb,
637
- sycl::bit_cast< sycl::half>(beta),
639
+ * reinterpret_cast < const sycl::half * >(beta),
638
640
reinterpret_cast <sycl::half *>(C), ldc, {});
639
641
__FORCE_MKL_FLUSH__ (status);
640
642
return 0 ;
@@ -651,19 +653,20 @@ extern "C" int onemklHdot(syclQueue_t device_queue, int64_t n,
651
653
return 0 ;
652
654
}
653
655
654
- extern " C" int onemklHaxpy (syclQueue_t device_queue, int64_t n, uint16_t alpha,
656
+ extern " C" int onemklHaxpy (syclQueue_t device_queue, int64_t n, uint16_t * alpha,
655
657
const short *x, std::int64_t incx, short *y, int64_t incy) {
656
658
auto status = oneapi::mkl::blas::column_major::axpy (device_queue->val , n,
657
- sycl::bit_cast< sycl::half>(alpha),
659
+ * reinterpret_cast < const sycl::half * >(alpha),
658
660
reinterpret_cast <const sycl::half *>(x),
659
661
incx, reinterpret_cast <sycl::half *>(y), incy, {});
660
662
__FORCE_MKL_FLUSH__ (status);
661
663
return 0 ;
662
664
}
663
665
664
- extern " C" int onemklHscal (syclQueue_t device_queue, int64_t n, uint16_t alpha,
666
+ extern " C" int onemklHscal (syclQueue_t device_queue, int64_t n, uint16_t * alpha,
665
667
short *x, int64_t incx) {
666
- auto status = oneapi::mkl::blas::column_major::scal (device_queue->val , n, sycl::bit_cast<sycl::half>(alpha),
668
+ auto status = oneapi::mkl::blas::column_major::scal (device_queue->val , n,
669
+ *reinterpret_cast <const sycl::half *>(alpha),
667
670
reinterpret_cast <sycl::half *>(x), incx, {});
668
671
__FORCE_MKL_FLUSH__ (status);
669
672
return 0 ;
0 commit comments