
Commit 90f6637

CISC authored and Nexesenex committed
cuda : add f32 to bf16 copy op (ggml-org#12806)
This allows BF16 KV-cache on CUDA.
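BF16 keeps f32's sign bit and 8-bit exponent and truncates the mantissa from 23 to 7 bits, so a BF16 KV-cache costs the same memory as f16 while preserving f32's full dynamic range. As a minimal host-side sketch of the per-element conversion this op performs (round-to-nearest-even, matching what CUDA's __float2bfloat16() does; the helper names are invented for illustration and NaN inputs are not special-cased):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Illustrative only: round-to-nearest-even f32 -> bf16, keeping the
    // top 16 bits of the f32 bit pattern. NaN inputs are not special-cased.
    static uint16_t f32_to_bf16_bits(float f) {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        const uint32_t bias = 0x7FFF + ((u >> 16) & 1); // ties round to even
        return (uint16_t) ((u + bias) >> 16);
    }

    // Widening back is exact: a bf16 is the top half of an f32 pattern.
    static float bf16_bits_to_f32(uint16_t b) {
        const uint32_t u = (uint32_t) b << 16;
        float f;
        std::memcpy(&f, &u, sizeof(f));
        return f;
    }

    int main() {
        const float x = 3.14159265f;
        const uint16_t b = f32_to_bf16_bits(x);
        printf("f32 %.8f -> bf16 0x%04x -> f32 %.8f\n", x, b, bf16_bits_to_f32(b));
        return 0;
    }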
1 parent a934ab7 commit 90f6637

File tree

2 files changed: +24 -0 lines changed


ggml/src/ggml-cuda/cpy.cu

Lines changed: 21 additions & 0 deletions
@@ -10,6 +10,13 @@ static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
+
+    *dsti = *xi;
+}
+
 static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     half * dsti = (half *) cdsti;
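In the new functor, *dsti = *xi relies on nv_bfloat16's converting assignment from float, declared in <cuda_bf16.h>. A standalone sketch of the same per-element conversion outside ggml's templated copy kernel (the kernel name and test values are invented for illustration; assumes CUDA 11+ so __bfloat162float() is host-callable):

    #include <cuda_bf16.h>
    #include <cstdio>

    // Sketch only, not the ggml kernel: the same elementwise conversion
    // cpy_1_f32_bf16 performs, via nv_bfloat16's assignment from float.
    __global__ void f32_to_bf16(const float * x, nv_bfloat16 * y, int n) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            y[i] = x[i]; // implicit __float2bfloat16() conversion
        }
    }

    int main() {
        const int n = 4;
        const float hx[n] = {1.0f, 0.1f, -2.5f, 3.14159265f};
        nv_bfloat16 hy[n];

        float * dx; nv_bfloat16 * dy;
        cudaMalloc(&dx, n * sizeof(float));
        cudaMalloc(&dy, n * sizeof(nv_bfloat16));
        cudaMemcpy(dx, hx, n * sizeof(float), cudaMemcpyHostToDevice);

        f32_to_bf16<<<1, 32>>>(dx, dy, n);
        cudaMemcpy(hy, dy, n * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost);

        for (int i = 0; i < n; i++) {
            printf("%.8f -> %.8f\n", hx[i], __bfloat162float(hy[i]));
        }
        cudaFree(dx); cudaFree(dy);
        return 0;
    }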
@@ -436,6 +443,16 @@ static void ggml_cpy_f32_f32_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+static void ggml_cpy_f32_bf16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
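The new launcher simply instantiates the existing cpy_f32_f16 kernel template with the cpy_1_f32_bf16 functor and sizes the grid by ceiling division, so num_blocks * CUDA_CPY_BLOCK_SIZE >= ne and every element is covered by a thread. A toy illustration (the block-size value of 64 is an assumption for this example; the real constant is defined in the ggml CUDA headers):

    #include <cstdio>

    // Assumed value for illustration; ggml defines CUDA_CPY_BLOCK_SIZE itself.
    constexpr int CUDA_CPY_BLOCK_SIZE = 64;

    int main() {
        const int ne = 1000; // example element count
        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
        // 16 blocks * 64 threads = 1024 threads; threads with index >= ne do nothing
        printf("%d elements -> %d blocks of %d threads\n", ne, num_blocks, CUDA_CPY_BLOCK_SIZE);
        return 0;
    }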
@@ -666,6 +683,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
@@ -723,6 +742,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 3 additions & 0 deletions
@@ -3325,6 +3325,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }
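This last hunk is the capability gate: ggml_backend_cuda_device_supports_op must report the F32 to BF16 GGML_OP_CPY pair as supported, otherwise the scheduler would never route the copy to the CUDA backend. A condensed, hypothetical mirror of the check (cuda_cpy_supported is an invented helper; the real function handles many more type pairs):

    #include "ggml.h"
    #include <cstdio>

    // Hypothetical condensed mirror of the supports_op branch above.
    static bool cuda_cpy_supported(ggml_type src0_type, ggml_type src1_type) {
        if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32)  return true;
        if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) return true; // added by this commit
        if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16)  return true;
        // ... remaining pairs elided ...
        return false;
    }

    int main() {
        printf("F32 -> BF16 copy supported: %s\n",
               cuda_cpy_supported(GGML_TYPE_F32, GGML_TYPE_BF16) ? "yes" : "no");
        return 0;
    }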
