
Commit c13fcfb

cuda : batched cuBLAS GEMMs for src0 F16 and src1 F32 (attention ops)

1 parent 84d4ca0 commit c13fcfb
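
Note (not part of the original commit message): the core of this change is a single cublasGemmBatchedEx call that runs one GEMM per (i12, i13) channel pair, driven by a device-side table of matrix pointers. Below is a minimal, self-contained sketch of that pattern; all names are local to the example (nothing here is ggml code) and error checking is omitted for brevity.

    #include <cublas_v2.h>
    #include <cuda_fp16.h>
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <vector>

    int main() {
        const int m = 4, n = 4, k = 4; // per-GEMM dimensions (invented)
        const int batch = 8;           // number of independent GEMMs

        cublasHandle_t handle;
        cublasCreate(&handle);

        // one contiguous allocation per operand, viewed as `batch` matrices
        half *A, *B, *C;
        cudaMalloc(&A, batch*m*k*sizeof(half));
        cudaMalloc(&B, batch*k*n*sizeof(half));
        cudaMalloc(&C, batch*m*n*sizeof(half));
        cudaMemset(A, 0, batch*m*k*sizeof(half));
        cudaMemset(B, 0, batch*k*n*sizeof(half));

        // host-side pointer tables, one entry per matrix in the batch
        std::vector<const void *> hA(batch), hB(batch);
        std::vector<void *> hC(batch);
        for (int i = 0; i < batch; ++i) {
            hA[i] = A + i*m*k;
            hB[i] = B + i*k*n;
            hC[i] = C + i*m*n;
        }

        // the pointer tables themselves must live in device memory
        const void **dA; const void **dB; void **dC;
        cudaMalloc(&dA, batch*sizeof(void *));
        cudaMalloc(&dB, batch*sizeof(void *));
        cudaMalloc(&dC, batch*sizeof(void *));
        cudaMemcpy(dA, hA.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);
        cudaMemcpy(dB, hB.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);
        cudaMemcpy(dC, hC.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);

        // fp16 scaling factors, matching CUBLAS_COMPUTE_16F
        const half alpha = 1.0f;
        const half beta  = 0.0f;

        cublasStatus_t st = cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                m, n, k,
                &alpha, dA, CUDA_R_16F, m,
                        dB, CUDA_R_16F, k,
                &beta,  dC, CUDA_R_16F, m,
                batch,
                CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
        printf("cublasGemmBatchedEx status: %d\n", (int) st);

        cudaFree(dA); cudaFree(dB); cudaFree(dC);
        cudaFree(A);  cudaFree(B);  cudaFree(C);
        cublasDestroy(handle);
        return 0;
    }

Built with nvcc and linked against -lcublas, this mirrors the pointer-table setup in the diff below, except that the diff strides the pointers through ggml tensors (with the r2/r3 broadcast) instead of a contiguous allocation.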

File tree

ggml-cuda.cu: 1 file changed, +165 -4 lines

ggml-cuda.cu

Lines changed: 165 additions & 4 deletions
@@ -7013,7 +7013,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
 }
 
 static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
-    GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -7023,11 +7024,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
 
-    const int64_t ne12 = src1->ne[2];
-
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
 
+    const int64_t ne12 = src1->ne[2];
+
     CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
@@ -7046,6 +7047,154 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
+static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
+    const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+
+    const int64_t ne1 = ggml_nelements(src1);
+    const int64_t ne  = ggml_nelements(dst);
+
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
+    half * src0_as_f16 = (half *) src0_ddq;
+
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    // convert src1 to fp16
+    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+    GGML_ASSERT(to_fp16_cuda != nullptr);
+
+    size_t src1_as = 0;
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+
+    size_t dst_as = 0;
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const half alpha_f16 = 1.0f;
+    const half beta_f16  = 0.0f;
+
+#if 0
+    // use cublasGemmEx
+    {
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                CUBLAS_CHECK(
+                        cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha_f16, (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
+                                        (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                            &beta_f16,  (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
+                            CUBLAS_COMPUTE_16F,
+                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            }
+        }
+    }
+#else
+    // use cublasGemmBatchedEx
+    {
+        const int ne23 = ne12*ne13;
+
+        // TODO: avoid this alloc
+        void ** src0_ptrs = (void **) malloc(ne23*sizeof(void *));
+        void ** src1_ptrs = (void **) malloc(ne23*sizeof(void *));
+        void **  dst_ptrs = (void **) malloc(ne23*sizeof(void *));
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                src0_ptrs[i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
+                src1_ptrs[i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
+                dst_ptrs [i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
+            }
+        }
+
+        // allocate device memory for pointers
+        void ** src0_ptrs_as = nullptr;
+        void ** src1_ptrs_as = nullptr;
+        void **  dst_ptrs_as = nullptr;
+
+        CUDA_CHECK(cudaMalloc(&src0_ptrs_as, ne23*sizeof(void *)));
+        CUDA_CHECK(cudaMalloc(&src1_ptrs_as, ne23*sizeof(void *)));
+        CUDA_CHECK(cudaMalloc(& dst_ptrs_as, ne23*sizeof(void *)));
+
+        // copy pointers to device
+        CUDA_CHECK(cudaMemcpy(src0_ptrs_as, src0_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpy(src1_ptrs_as, src1_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpy( dst_ptrs_as,  dst_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
+
+        CUBLAS_CHECK(
+                cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                    ne01, ne11, ne10,
+                    &alpha_f16, (void **) src0_ptrs_as, CUDA_R_16F, nb01/sizeof(half),
+                                (void **) src1_ptrs_as, CUDA_R_16F, nb11/sizeof(float),
+                    &beta_f16,  (void **)  dst_ptrs_as, CUDA_R_16F, ne01,
+                    ne23,
+                    CUBLAS_COMPUTE_16F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        // free device memory for pointers
+        CUDA_CHECK(cudaFree(src0_ptrs_as));
+        CUDA_CHECK(cudaFree(src1_ptrs_as));
+        CUDA_CHECK(cudaFree( dst_ptrs_as));
+
+        free(src0_ptrs);
+        free(src1_ptrs);
+        free( dst_ptrs);
+    }
+#endif
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
+}
+
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
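
Aside (illustration, not part of the diff): the broadcast factors r2 = ne12/ne02 and r3 = ne13/ne03 in the function above let several src1 channels share one src0 channel, which is how a single K/V head can serve multiple Q heads (e.g. under grouped-query attention). A tiny sketch of the index mapping, with dimensions invented for the example:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // invented dims: 32 src1 (Q) channels broadcast over 8 src0 (K/V) channels
        const int64_t ne02 = 8, ne12 = 32;
        const int64_t r2 = ne12/ne02; // broadcast factor, computed as in the diff
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            // the same mapping the pointer-table loop uses: i02 = i12 / r2
            printf("src1 channel %2lld -> src0 channel %lld\n",
                   (long long) i12, (long long) (i12/r2));
        }
        return 0;
    }

With these numbers, src1 channels 0..3 all read src0 channel 0, channels 4..7 read channel 1, and so on, which is exactly what the GGML_ASSERT(ne12 % ne02 == 0) check guarantees is well defined.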
@@ -7058,10 +7207,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
         }
     }
 
+    // debug helpers
+    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
     if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // KQ
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // KQV
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
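
One more note on the dispatch change (commentary, not from the commit): the new third branch is taken only for untransposed F16 x F32 mat-mats where src1 has more than one channel, i.e. a genuine batch of GEMMs rather than the single-row mat-vec cases above it. A hypothetical helper restating that condition, assuming the ggml headers:

    #include "ggml.h"

    // hypothetical restatement of the new dispatch predicate; not a ggml API
    static bool use_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1,
                                   bool all_on_device) {
        return all_on_device               &&
               src0->type == GGML_TYPE_F16 &&
               src1->type == GGML_TYPE_F32 &&
               !ggml_is_transposed(src0)   &&
               !ggml_is_transposed(src1)   &&
               src1->ne[2]*src1->ne[3] > 1;  // more than one channel => batch of GEMMs
    }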
