Commit 7c8a37b

cuda : speed-up by using CUBLAS_COMPUTE_32F instead of CUBLAS_COMPUTE_16F (ggml-org#3816)

Merge commit, 2 parents: 283f248 + c830a05
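The change keeps the FP16 inputs of the matrix multiplications but asks cuBLAS to accumulate and write the result in FP32 (CUBLAS_COMPUTE_32F with a CUDA_R_32F destination). That removes the temporary FP16 output buffer and the separate fp16-to-fp32 conversion kernel that the old CUBLAS_COMPUTE_16F path needed. Below is a minimal, self-contained sketch of the new call pattern; it is not the ggml code itself, and the names and sizes are illustrative:

```cpp
// Sketch: FP16 inputs, FP32 accumulation and FP32 output in one cuBLAS call.
// Illustrative only; error checking omitted for brevity.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

int main() {
    const int m = 64, n = 32, k = 128;

    half  *a, *b;   // FP16 input matrices (column-major for cuBLAS)
    float *c;       // FP32 output, no separate conversion pass needed
    cudaMalloc(&a, (size_t)m*k*sizeof(half));
    cudaMalloc(&b, (size_t)k*n*sizeof(half));
    cudaMalloc(&c, (size_t)m*n*sizeof(float));

    cublasHandle_t handle;
    cublasCreate(&handle);

    // Scaling factors must match the compute type: float for CUBLAS_COMPUTE_32F.
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                 m, n, k,
                 &alpha, a, CUDA_R_16F, m,   // A: FP16
                         b, CUDA_R_16F, k,   // B: FP16
                 &beta,  c, CUDA_R_32F, m,   // C: FP32, written directly
                 CUBLAS_COMPUTE_32F,         // accumulate in FP32
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    cudaDeviceSynchronize();
    cublasDestroy(handle);
    cudaFree(a); cudaFree(b); cudaFree(c);
    return 0;
}
```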


ggml-cuda.cu: 23 additions, 37 deletions
@@ -6607,27 +6607,20 @@ inline void ggml_cuda_op_mul_mat_cublas(
         to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
     }
     const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
 
-    const half alpha_f16 = 1.0f;
-    const half beta_f16 = 0.0f;
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
 
     CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
     CUBLAS_CHECK(
         cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 row_diff, src1_ncols, ne10,
-                &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
-                            src1_ptr, CUDA_R_16F, ne10,
-                &beta_f16,   dst_f16, CUDA_R_16F, ldc,
-                CUBLAS_COMPUTE_16F,
+                &alpha, src0_ptr, CUDA_R_16F, ne00,
+                        src1_ptr, CUDA_R_16F, ne10,
+                &beta,  dst_dd_i, CUDA_R_32F, ldc,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
-
-    ggml_cuda_pool_free(dst_f16, dst_as);
-
     if (src0_as != 0) {
         ggml_cuda_pool_free(src0_as_f16, src0_as);
     }
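Net effect of this hunk: the pooled FP16 destination buffer (dst_f16), the fp16-to-fp32 conversion launch, and the matching pool alloc/free all disappear; cublasGemmEx now writes FP32 directly into dst_dd_i, saving one kernel launch and one pool allocation per matrix multiplication.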
@@ -7436,7 +7429,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
 }
 
 __global__ static void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+        const half * src0_as_f16, const half * src1_as_f16, float * dst_f32,
         const void ** ptrs_src, void ** ptrs_dst,
         int ne12, int ne13,
         int ne23,
@@ -7456,7 +7449,7 @@ __global__ static void k_compute_batched_ptrs(
 
     ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02   + i03*nb03;
     ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)     dst_f32 + i12* nb2   + i13* nb3;
 }
 
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
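Why the /2 factors on the destination vanish: nb2 and nb3 are byte strides of dst, which is an FP32 tensor. The old code wrote into a separate FP16 staging buffer, so every byte offset had to be halved; the new code indexes the FP32 buffer itself and can use the tensor's byte strides unchanged. The /2 on the src1 offsets remains because src1 is still staged in FP16, half the size of its FP32 source.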
@@ -7513,18 +7506,15 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
 
     // broadcast factors
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    const half alpha_f16 = 1.0f;
-    const half beta_f16 = 0.0f;
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
 
 #if 0
     // use cublasGemmEx
@@ -7537,10 +7527,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
             CUBLAS_CHECK(
                     cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                         ne01, ne11, ne10,
-                        &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
-                                    (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
-                        &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
-                        CUBLAS_COMPUTE_16F,
+                        &alpha, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
+                                (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                        &beta,  (      char *)     dst_ddf + i12* dst->nb[2]   + i13* dst->nb[3]  , CUDA_R_32F, ne01,
+                        CUBLAS_COMPUTE_32F,
                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
         }
     }
@@ -7552,11 +7542,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
-                            (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F, ne01,                dst->nb[2]/sizeof(float), // strideC
+                &alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                        (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                &beta,  (      char *)     dst_ddf, CUDA_R_32F, ne01,                dst->nb[2]/sizeof(float), // strideC
                 ne12*ne13,
-                CUBLAS_COMPUTE_16F,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
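The strided-batched variant covers the case where consecutive matrices in the batch sit at a constant element stride, so one call handles all ne12*ne13 multiplications without pointer tables; the else branch below builds explicit per-matrix pointer arrays for layouts that are not uniform (for instance when the broadcast factors r2/r3 repeat src0 matrices). Note that strideC was already expressed in dst elements (dst->nb[2]/sizeof(float)); only its storage type changes here, from an FP16 staging buffer to the FP32 dst_ddf itself.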
@@ -7573,7 +7563,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_as_f16, src1_as_f16, dst_f16,
+                src0_as_f16, src1_as_f16, dst_ddf,
                 ptrs_src, ptrs_dst,
                 ne12, ne13,
                 ne23,
@@ -7586,11 +7576,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
+                &alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                        (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_32F, ne01,
                 ne23,
-                CUBLAS_COMPUTE_16F,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         if (ptrs_src_s != 0) {
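For the non-uniform case, ggml fills ptrs_src/ptrs_dst on the device with k_compute_batched_ptrs (now storing float * destinations, as patched above) and hands the tables to cublasGemmBatchedEx. A simplified sketch of that pattern follows; it builds the pointer tables on the host for clarity, uses a flat batch layout instead of ggml's ne12/ne13 broadcast indexing, and omits error checking:

```cpp
// Sketch of the batched-pointers path: one device array of per-matrix
// pointers, then a single batched GEMM with FP16 inputs and FP32 output.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

int main() {
    const int m = 16, n = 16, k = 16, batch = 4;

    half  *a, *b;
    float *c;
    cudaMalloc(&a, (size_t)batch*m*k*sizeof(half));
    cudaMalloc(&b, (size_t)batch*k*n*sizeof(half));
    cudaMalloc(&c, (size_t)batch*m*n*sizeof(float));

    // Host-side pointer tables, copied to the device (ggml builds these on
    // the GPU with the k_compute_batched_ptrs kernel instead).
    std::vector<const void*> ha(batch), hb(batch);
    std::vector<void*>       hc(batch);
    for (int i = 0; i < batch; ++i) {
        ha[i] = a + (size_t)i*m*k;
        hb[i] = b + (size_t)i*k*n;
        hc[i] = c + (size_t)i*m*n;
    }
    const void **da; const void **db; void **dc;
    cudaMalloc(&da, batch*sizeof(void*));
    cudaMalloc(&db, batch*sizeof(void*));
    cudaMalloc(&dc, batch*sizeof(void*));
    cudaMemcpy(da, ha.data(), batch*sizeof(void*), cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb.data(), batch*sizeof(void*), cudaMemcpyHostToDevice);
    cudaMemcpy(dc, hc.data(), batch*sizeof(void*), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    const float alpha = 1.0f, beta = 0.0f;
    cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        m, n, k,
                        &alpha, da, CUDA_R_16F, m,
                                db, CUDA_R_16F, k,
                        &beta,  dc, CUDA_R_32F, m,   // FP32 destinations
                        batch,
                        CUBLAS_COMPUTE_32F,
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    cudaDeviceSynchronize();
    cublasDestroy(handle);
    cudaFree(da); cudaFree(db); cudaFree(dc);
    cudaFree(a); cudaFree(b); cudaFree(c);
    return 0;
}
```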
@@ -7602,11 +7592,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     }
 #endif
 
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
     ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
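Besides dropping the conversion pass, FP32 accumulation also avoids the rounding error that FP16 accumulation incurs over long dot products, so the switch should not cost accuracy. On tensor-core GPUs the FP16-in/FP32-accumulate GEMM typically runs at the same rate as the pure FP16 one, which is presumably why the change is a net speed-up, as the commit title reports.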
