cuda : add batched cuBLAS GEMM for faster attention #3749
```diff
@@ -4326,13 +4326,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous

     const half * x = (const half *) vx;

-    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
     const int channel_x = channel / channel_x_divisor;

-    const int nrows_y = ncols_x;
+    const int nrows_y   = ncols_x;
     const int nrows_dst = nrows_x;
-    const int row_dst = row_x;
+    const int row_dst   = row_x;

     const int idst = channel*nrows_dst + row_dst;
```
```diff
@@ -4345,13 +4345,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
             break;
         }

-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const float xi = __half2float(x[ix]);
-
         const int row_y = col_x;

+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
         const int iy = channel*nrows_y + row_y;

+        const float xi = __half2float(x[ix]);
+
         tmp += xi * y[iy];
     }

```
```diff
@@ -7013,7 +7013,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
 }

 static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
-    GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
```
```diff
@@ -7023,11 +7024,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];

-    const int64_t ne12 = src1->ne[2];
-
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];

+    const int64_t ne12 = src1->ne[2];
+
     CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];

```
```diff
@@ -7046,6 +7047,154 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }

+static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
+    const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+
+    const int64_t ne1 = ggml_nelements(src1);
+    const int64_t ne  = ggml_nelements(dst);
+
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
+    half * src0_as_f16 = (half *) src0_ddq;
+
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    // convert src1 to fp16
+    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+    GGML_ASSERT(to_fp16_cuda != nullptr);
+
+    size_t src1_as = 0;
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+
+    size_t dst_as = 0;
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const half alpha_f16 = 1.0f;
+    const half beta_f16  = 0.0f;
+
+#if 0
+    // use cublasGemmEx
+    {
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                CUBLAS_CHECK(
+                    cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                        ne01, ne11, ne10,
+                        &alpha_f16, (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
+                                    (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                        &beta_f16,  (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
+                        CUBLAS_COMPUTE_16F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            }
+        }
+    }
+#else
+    // use cublasGemmBatchedEx
+    {
+        const int ne23 = ne12*ne13;
+
+        // TODO: avoid this alloc
+        void ** src0_ptrs = (void **) malloc(ne23*sizeof(void *));
+        void ** src1_ptrs = (void **) malloc(ne23*sizeof(void *));
+        void **  dst_ptrs = (void **) malloc(ne23*sizeof(void *));
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                src0_ptrs[i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
+                src1_ptrs[i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
+                dst_ptrs [i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
+            }
+        }
+
+        // allocate device memory for pointers
+        void ** src0_ptrs_as = nullptr;
+        void ** src1_ptrs_as = nullptr;
+        void **  dst_ptrs_as = nullptr;
+
+        CUDA_CHECK(cudaMalloc(&src0_ptrs_as, ne23*sizeof(void *)));
+        CUDA_CHECK(cudaMalloc(&src1_ptrs_as, ne23*sizeof(void *)));
+        CUDA_CHECK(cudaMalloc(& dst_ptrs_as, ne23*sizeof(void *)));
+
+        // copy pointers to device
+        CUDA_CHECK(cudaMemcpy(src0_ptrs_as, src0_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpy(src1_ptrs_as, src1_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpy( dst_ptrs_as,  dst_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
```
Inline review comments attached to the pointer copies above (quoted from the PR conversation):

- "Does anyone know if I changed these"
- "You only need to run them on the same stream, but I don't think that this can be made async because the host memory may already be freed by the time the copy happens. Running memcpy asynchronously also requires using host pinned memory."
- "I reduced the mallocs from 3 to 1, but when I try to replace it with"
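The second comment above spells out the constraint on making these copies asynchronous: the host buffer must be page-locked and must stay alive until the copy has completed. Below is a minimal sketch of that pattern; the helper name and its parameters are hypothetical and not code from the PR, while `CUDA_CHECK` is the error-checking macro already used throughout `ggml-cuda.cu`.

```cpp
#include <cstring>
#include <cuda_runtime.h>

// Upload a table of device pointers via a pinned staging buffer.
// cudaMemcpyAsync is only truly asynchronous when the host buffer is page-locked;
// with pageable memory the runtime may fall back to a synchronous copy.
static void upload_ptr_table_async(void ** dev_ptrs, void * const * host_ptrs,
                                   size_t count, cudaStream_t stream) {
    // pinned (page-locked) staging buffer
    void ** staging = nullptr;
    CUDA_CHECK(cudaMallocHost((void **) &staging, count*sizeof(void *)));
    memcpy(staging, host_ptrs, count*sizeof(void *));

    // enqueue on the same stream as the following cublasGemmBatchedEx call,
    // so the GEMM is ordered after the pointer table has arrived on the device
    CUDA_CHECK(cudaMemcpyAsync(dev_ptrs, staging, count*sizeof(void *),
                               cudaMemcpyHostToDevice, stream));

    // the staging buffer must outlive the copy; synchronizing here keeps the
    // sketch simple (an event or stream callback could defer the free instead)
    CUDA_CHECK(cudaStreamSynchronize(stream));
    CUDA_CHECK(cudaFreeHost(staging));
}
```

As the comment notes, the synchronization is the awkward part: without it, freeing the staging buffer (or the original `malloc`'d tables) races with the copy.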
```diff
+
+        CUBLAS_CHECK(
+        cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, (void **) src0_ptrs_as, CUDA_R_16F, nb01/sizeof(half),
+                            (void **) src1_ptrs_as, CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (void **)  dst_ptrs_as, CUDA_R_16F, ne01,
+                ne23,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        // free device memory for pointers
+        CUDA_CHECK(cudaFree(src0_ptrs_as));
+        CUDA_CHECK(cudaFree(src1_ptrs_as));
+        CUDA_CHECK(cudaFree( dst_ptrs_as));
+
+        free(src0_ptrs);
+        free(src1_ptrs);
+        free( dst_ptrs);
+    }
+#endif
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
+}
+
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
```
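The `// TODO: avoid this alloc` in the hunk above refers to the three host-side pointer tables. When there is no broadcasting (ne12 == ne02, ne13 == ne03) and every matrix in a batch sits at a fixed element stride from the previous one, the same computation can be expressed with cublasGemmStridedBatchedEx and no pointer tables at all. The sketch below is only an illustration under those assumptions and is not what this PR implements; the helper and its parameters are hypothetical.

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Strided-batched fp16 GEMM: one call for batch_count matrix products, with the
// i-th A/B/C matrix located at a fixed element offset i*stride from the first.
// This avoids building per-matrix pointer tables on the host and the device.
static cublasStatus_t gemm_f16_strided_batched(
        cublasHandle_t handle,
        const half * src0_f16, int lda, long long stride_src0,   // K-like operand, transposed in the GEMM
        const half * src1_f16, int ldb, long long stride_src1,   // Q-like operand
        half       * dst_f16,  int ldc, long long stride_dst,
        int m, int n, int k, int batch_count) {
    const half alpha = 1.0f;
    const half beta  = 0.0f;

    return cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
            m, n, k,
            &alpha, src0_f16, CUDA_R_16F, lda, stride_src0,
                    src1_f16, CUDA_R_16F, ldb, stride_src1,
            &beta,  dst_f16,  CUDA_R_16F, ldc, stride_dst,
            batch_count,
            CUBLAS_COMPUTE_16F,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```

In terms of the diff, m/n/k would correspond to ne01/ne11/ne10 and batch_count to ne23; the pointer-table version remains the general one because it also covers the broadcast case where several src1 batches reuse the same src0 batch.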
```diff
@@ -7058,10 +7207,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }

+    // debug helpers
+    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
     if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // KQ
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // KQV
        ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
```
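The new branch selects the batched path whenever src1 carries more than one matrix in dims 2/3 (src1->ne[2]*src1->ne[3] > 1), which is the case for the per-head KQ and KQV multiplications in attention. Inside the new function, the broadcast factors r2 = ne12/ne02 and r3 = ne13/ne03 then let several src1 batches reuse a single src0 batch, e.g. when K/V have fewer heads than Q. A small standalone illustration of that index mapping, with made-up head counts:

```cpp
#include <cstdio>

// Mirrors the pointer-table loop above: each src1 batch i12 (e.g. a query head)
// reads from src0 batch i02 = i12 / r2 (e.g. a shared KV head).
// The counts below are arbitrary example values, not taken from the PR.
int main() {
    const int ne02 = 8;            // src0 batches (e.g. KV heads)
    const int ne12 = 32;           // src1 batches (e.g. query heads)
    const int r2   = ne12 / ne02;  // broadcast factor (4 here)

    for (int i12 = 0; i12 < ne12; ++i12) {
        const int i02 = i12 / r2;
        printf("src1 batch %2d -> src0 batch %d\n", i12, i02);
    }
    return 0;
}
```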