@@ -358,6 +358,10 @@ void SELayer<sycl::half>::Eval(int N, sycl::half* output, const sycl::half* inpu
358
358
half alpha = one_h;
359
359
half beta = zero_h;
360
360
361
+ #elif defined(USE_HIPBLAS)
362
+ hipblasHalf alpha{1 };
363
+ hipblasHalf beta{0 };
364
+
361
365
#else
362
366
sycl::half alpha = 1 ;
363
367
sycl::half beta = 0 ;
@@ -393,10 +397,10 @@ void SELayer<sycl::half>::Eval(int N, sycl::half* output, const sycl::half* inpu
393
397
sycl::get_native<sycl::backend::ext_oneapi_hip>(sycl_queue);
394
398
hipblasSetStream (handle, hipStreamHandle);
395
399
396
- hipblasSgemm (handle, transpose_type_transpose,
400
+ hipblasHgemm (handle, transpose_type_transpose,
397
401
transpose_type_notranspose,numFc1Out_, N, C, &alpha,
398
- ((const sycl::half *)w1_), C, ((const sycl::half *)op2), C,
399
- &beta, ((sycl::half *)op1), numFc1Out_);
402
+ ((const hipblasHalf *)w1_), C, ((const hipblasHalf *)op2), C,
403
+ &beta, ((hipblasHalf *)op1), numFc1Out_);
400
404
401
405
hipStreamSynchronize (hipStreamHandle);
402
406
});
@@ -436,10 +440,10 @@ void SELayer<sycl::half>::Eval(int N, sycl::half* output, const sycl::half* inpu
436
440
sycl::get_native<sycl::backend::ext_oneapi_hip>(sycl_queue);
437
441
hipblasSetStream (handle, hipStreamHandle);
438
442
439
- hipblasSgemm (
443
+ hipblasHgemm (
440
444
handle, transpose_type_transpose, transpose_type_notranspose, 2 * C,
441
- N, numFc1Out_, &alpha,((const sycl::half *)w2_), numFc1Out_,
442
- ((const sycl::half *)op1), numFc1Out_, &beta, ((sycl::half *)op2),
445
+ N, numFc1Out_, &alpha,((const hipblasHalf *)w2_), numFc1Out_,
446
+ ((const hipblasHalf *)op1), numFc1Out_, &beta, ((hipblasHalf *)op2),
443
447
2 * C);
444
448
445
449
hipStreamSynchronize (hipStreamHandle);
@@ -544,6 +548,10 @@ template <>
544
548
half alpha = one_h;
545
549
half beta = zero_h;
546
550
551
+ #elif defined(USE_HIPBLAS)
552
+ hipblasHalf alpha{1 };
553
+ hipblasHalf beta{0 };
554
+
547
555
#else
548
556
sycl::half alpha = 1 ;
549
557
sycl::half beta = 0 ;
@@ -576,11 +584,11 @@ template <>
576
584
sycl::get_native<sycl::backend::ext_oneapi_hip>(sycl_queue);
577
585
hipblasSetStream (handle, hipStreamHandle);
578
586
579
- hipblasSgemm (
587
+ hipblasHgemm (
580
588
handle, transpose_type_transpose, transpose_type_notranspose,
581
- num_outputs, N, num_inputs, &alpha, ((const sycl::half *)weights_),
582
- num_inputs, ((const sycl::half *)input_tensor), num_inputs, &beta,
583
- ((sycl::half *)output_tensor), num_outputs);
589
+ num_outputs, N, num_inputs, &alpha, ((const hipblasHalf *)weights_),
590
+ num_inputs, ((const hipblasHalf *)input_tensor), num_inputs, &beta,
591
+ ((hipblasHalf *)output_tensor), num_outputs);
584
592
585
593
hipStreamSynchronize (hipStreamHandle);
586
594
});
@@ -964,7 +972,7 @@ template <>
964
972
965
973
hipStreamSynchronize (hipStreamHandle);
966
974
});
967
- );
975
+ } );
968
976
#else
969
977
int64_t M_ = M;
970
978
int64_t N_ = N;
@@ -1807,7 +1815,20 @@ static void cublasXgemm(transpose_type transa,
1807
1815
});
1808
1816
}
1809
1817
#elif defined(USE_HIPBLAS)
1810
- hipblasHandle_t handle = hipBlasContextManager::gethipBlasHandle_t ();
1818
+ hipblasHandle_t handle = hipBlasContextManager::gethipBlasHandle_t ();
1819
+ if (fp16) {
1820
+ unsigned short alpha_h = FP32toFP16 (alpha);
1821
+ unsigned short beta_h = FP32toFP16 (beta);
1822
+ sycl_queue.submit ([&](sycl::handler &cgh) {
1823
+ cgh.host_task ([=](sycl::interop_handle ih) {
1824
+ auto hipStreamHandle = sycl::get_native<sycl::backend::ext_oneapi_hip>(sycl_queue);
1825
+ hipblasSetStream (handle, hipStreamHandle);
1826
+ hipblasHgemm (handle, transa, transb, m, n, k, &alpha_h, (const hipblasHalf*)A,
1827
+ lda, (const hipblasHalf*)B, ldb, &beta_h, (hipblasHalf*)C, ldc);
1828
+ hipStreamSynchronize (hipStreamHandle);
1829
+ });
1830
+ });
1831
+ } else {
1811
1832
sycl_queue.submit ([&](sycl::handler &cgh) {
1812
1833
cgh.host_task ([=](sycl::interop_handle ih) {
1813
1834
auto hipStreamHandle = sycl::get_native<sycl::backend::ext_oneapi_hip>(sycl_queue);
@@ -1816,6 +1837,7 @@ static void cublasXgemm(transpose_type transa,
1816
1837
hipStreamSynchronize (hipStreamHandle);
1817
1838
});
1818
1839
});
1840
+ }
1819
1841
#else
1820
1842
oneapi::mkl::blas::column_major::gemm (sycl_queue, transa, transb, m, n, k, alpha, (const DataType *)A, lda,
1821
1843
(const DataType *)B, ldb, beta, (DataType *)C, ldc);
@@ -1873,9 +1895,29 @@ static void cublasXGemmStridedBatched(transpose_type transa, transpose_type tran
1873
1895
});
1874
1896
}
1875
1897
#elif defined(USE_HIPBLAS)
1876
- hipblasHandle_t handle = hipBlasContextManager::gethipBlasHandle_t ();
1898
+ hipblasHandle_t handle = hipBlasContextManager::gethipBlasHandle_t ();
1899
+ if (fp16) {
1900
+ unsigned short alpha_h = FP32toFP16 (alpha);
1901
+ unsigned short beta_h = FP32toFP16 (beta);
1902
+
1903
+ sycl_queue.submit ([&](sycl::handler &cgh) {
1904
+
1905
+ cgh.host_task ([=](sycl::interop_handle ih) {
1906
+
1907
+ auto hipStreamHandle = sycl::get_native<sycl::backend::ext_oneapi_hip>(sycl_queue);
1908
+ hipblasSetStream (handle, hipStreamHandle);
1909
+
1910
+ hipblasGemmStridedBatchedEx (
1911
+ handle, transa, transb, m, n, k, &alpha_h, A, HIPBLAS_R_16F, lda, strideA, B,
1912
+ HIPBLAS_R_16F, ldb, strideB, &beta_h, C, HIPBLAS_R_16F, ldc, strideC,
1913
+ batchCount, HIPBLAS_R_16F, HIPBLAS_GEMM_DEFAULT);
1914
+
1915
+ hipStreamSynchronize (hipStreamHandle);
1877
1916
1878
- sycl_queue.submit ([&](sycl::handler &cgh) {
1917
+ });
1918
+ });
1919
+ } else {
1920
+ sycl_queue.submit ([&](sycl::handler &cgh) {
1879
1921
1880
1922
cgh.host_task ([=](sycl::interop_handle ih) {
1881
1923
@@ -1891,9 +1933,10 @@ static void cublasXGemmStridedBatched(transpose_type transa, transpose_type tran
1891
1933
1892
1934
});
1893
1935
});
1894
- #else
1895
- oneapi::mkl::blas::column_major::gemm_batch (sycl_queue, transa, transb, m, n, k, alpha, (const DataType *)A, lda, strideA, (const DataType *)B, ldb, strideB, beta, (DataType *)C, ldc, strideC, batchCount);
1896
- #endif
1936
+ }
1937
+ #else
1938
+ oneapi::mkl::blas::column_major::gemm_batch (sycl_queue, transa, transb, m, n, k, alpha, (const DataType *)A, lda, strideA, (const DataType *)B, ldb, strideB, beta, (DataType *)C, ldc, strideC, batchCount);
1939
+ #endif
1897
1940
}
1898
1941
1899
1942
template <typename DataType>
@@ -1962,8 +2005,8 @@ static void cublasXGemmBatched(transpose_type transa,
1962
2005
hipblasSetStream (handle, hipStreamHandle);
1963
2006
1964
2007
hipblasHgemmBatched (
1965
- handle, transa, transb, m, n, k, (const half *)&alpha_h, (half **)A, lda,
1966
- (half **)B, ldb, (const half *)&beta_h, (half **)C, ldc, batchCount);
2008
+ handle, transa, transb, m, n, k, (const hipblasHalf *)&alpha_h, (hipblasHalf **)A, lda,
2009
+ (hipblasHalf **)B, ldb, (const hipblasHalf *)&beta_h, (hipblasHalf **)C, ldc, batchCount);
1967
2010
1968
2011
hipStreamSynchronize (hipStreamHandle);
1969
2012
@@ -2507,7 +2550,6 @@ template <typename DataType>
2507
2550
AttentionBody<DataType>::~AttentionBody () {
2508
2551
sycl::free (ip_emb_w_, sycl_queue_);
2509
2552
sycl::free (ip_emb_b_, sycl_queue_);
2510
- sycl::free (pos_encoding_, sycl_queue_);
2511
2553
if (is_pe_dense_embedding_) {
2512
2554
sycl::free (ip_emb_pre_w_, sycl_queue_);
2513
2555
sycl::free (ip_emb_pre_b_, sycl_queue_);
0 commit comments