Commit a9478f0

am17an authored and Minh141120 committed
CUDA: add bf16 and f32 support to cublas_mul_mat_batched (ggml-org#14361)
* CUDA: add bf16 and f32 support to cublas_mul_mat_batched
* Review: add type traits and make function more generic
* Review: make check more explicit, add back comments, and fix formatting
* Review: fix formatting, remove useless type conversion, fix naming for bools
1 parent 10d2f66 commit a9478f0

File tree

1 file changed: +47 -12 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 47 additions & 12 deletions
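One of the review points in the commit message was to introduce type traits so that the batched cuBLAS path is written once and instantiated for F16, BF16 and F32. The hunks below refer to cuda_t, traits::get_nc_converter and traits::ggml_type_val; the following is a minimal sketch of what such a traits helper could look like, with the struct name and the exact member set assumed for illustration rather than taken from the merged code:

#include <cuda_fp16.h>   // half
#include <cuda_bf16.h>   // nv_bfloat16
#include "ggml.h"        // ggml_type, GGML_TYPE_*

// Hypothetical traits sketch: map a compile-time ggml type to the CUDA storage
// type handed to cuBLAS and back to its runtime ggml_type value. Only the names
// cuda_t and ggml_type_val appear in the diff; everything else here is assumed.
template <ggml_type T> struct batched_mul_mat_traits;

template <> struct batched_mul_mat_traits<GGML_TYPE_F16> {
    using cuda_t = half;
    static constexpr ggml_type ggml_type_val = GGML_TYPE_F16;
};

template <> struct batched_mul_mat_traits<GGML_TYPE_BF16> {
    using cuda_t = nv_bfloat16;
    static constexpr ggml_type ggml_type_val = GGML_TYPE_BF16;
};

template <> struct batched_mul_mat_traits<GGML_TYPE_F32> {
    using cuda_t = float;
    static constexpr ggml_type ggml_type_val = GGML_TYPE_F32;
};

With a mapping of this shape, the implementation can allocate pool buffers of cuda_t and ask the traits for the right converter without branching on the tensor type at every step.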
@@ -1842,16 +1842,27 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+    const cuda_t * src0_ptr = nullptr;
+    const cuda_t * src1_ptr = nullptr;
+
+    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+    // Handle src0
+    src0_ptr = (const cuda_t *) src0->data;
+
+    // Handle src1 - convert if necessary
+    if (src1->type == src0_type) {
+        src1_ptr = (const cuda_t *) src1->data;
+    } else {
+        // Convert src1 to target type using traits conversion functions
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_alloc.alloc(ne_src1);
 
-        src1_f16 = src1_f16_alloc.get();
+        const auto convert_func = traits::get_nc_converter(src1->type);
+        GGML_ASSERT(convert_func != nullptr);
+        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+        src1_ptr = src1_alloc.get();
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
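From the call site in the hunk above, the converter returned by traits::get_nc_converter receives the raw src1 pointer, a destination buffer of the target CUDA type, the four logical extents ne10..ne13 and the strides s11, s12, s13, plus the CUDA stream. A plausible function-pointer shape, inferred from those arguments rather than copied from the repository, is:

#include <cstdint>
#include <cuda_runtime.h>   // cudaStream_t

// Inferred (not verbatim) converter type for the non-contiguous path: convert an
// ne10 x ne11 x ne12 x ne13 tensor with strides s11/s12/s13 into a contiguous
// buffer of the target CUDA type, asynchronously on the given stream.
template <typename cuda_t>
using nc_converter_t = void (*)(const void * src, cuda_t * dst,
                                int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13,
                                int64_t s11, int64_t s12, int64_t s13,
                                cudaStream_t stream);

After the conversion, the code resets s11/s12/s13 to the packed strides (ne10, ne11*s11, ne12*s12), since the temporary buffer is contiguous.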
@@ -1948,11 +1959,29 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
-#endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    // Convert output back to F32 if needed
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
+        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
+    }
+}
+
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
     }
 }
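The output conversion above is keyed on cu_data_type != CUDA_R_32F, i.e. only F16/BF16 results written to a temporary buffer are converted back to F32; an F32 GEMM already writes the final result. How cu_data_type and cu_compute_type are chosen is outside these hunks; the helper below is an illustrative sketch only (the enum values CUDA_R_16F, CUDA_R_16BF, CUDA_R_32F and CUBLAS_COMPUTE_32F are standard CUDA/cuBLAS types, but the mapping and the helper name are assumptions):

#include <cublas_v2.h>   // cudaDataType_t, cublasComputeType_t
#include "ggml.h"        // ggml_type

// Illustrative mapping from the instantiated source type to the cuBLAS enums
// referenced as cu_data_type / cu_compute_type in the hunk above. Not the merged code.
static void pick_cublas_types_sketch(ggml_type src0_type, cudaDataType_t & cu_data_type, cublasComputeType_t & cu_compute_type) {
    switch (src0_type) {
        case GGML_TYPE_F16:  cu_data_type = CUDA_R_16F;  break;   // half storage
        case GGML_TYPE_BF16: cu_data_type = CUDA_R_16BF; break;   // bfloat16 storage
        default:             cu_data_type = CUDA_R_32F;  break;   // float storage
    }
    cu_compute_type = CUBLAS_COMPUTE_32F;   // accumulate in fp32 in this sketch
}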

@@ -2004,6 +2033,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
+    //TODO update for generic tensor parallelism
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+    bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
+
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
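The three use_batched_cublas_* flags gate the new paths on type and hardware: BF16 additionally requires bf16 MMA support on the current device, F16 keeps the existing slow-FP16 heuristic, and F32 is always eligible. How they feed into the final dispatch is not part of this hunk; purely as an illustration of the intent (not the merged condition), they could be combined along these lines:

// Illustration only - the actual dispatch condition in ggml_cuda_mul_mat is not shown in this diff.
// ggml_is_transposed() and the ne[] extents are public ggml API; the combination below is assumed.
const bool use_batched_cublas = (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
                             && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)
                             && src1->ne[2]*src1->ne[3] > 1;   // more than one matrix in the batch
if (use_batched_cublas) {
    ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);    // wrapper added in this commit
}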
