
Commit 27208bf

CUDA: add bf16 and f32 support to cublas_mul_mat_batched (#14361)
* CUDA: add bf16 and f32 support to cublas_mul_mat_batched
* Review: add type traits and make function more generic
* Review: make check more explicit, add back comments, and fix formatting
* Review: fix formatting, remove useless type conversion, fix naming for bools
1 parent 63a7bb3 commit 27208bf
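
The review-requested type traits follow a common pattern: a runtime type tag is switched once into a template instantiation, and the traits specialization bundles every type-dependent constant that instantiation needs. Below is a minimal standalone sketch of that pattern (illustrative only; `my_type`, `mm_traits`, and `mul_mat_impl` are hypothetical stand-ins, not names from this commit — the real `batched_mul_mat_traits` in the diff also carry cuBLAS compute/data types and alpha/beta accessors):

```cpp
// Minimal sketch of mapping a runtime enum onto compile-time traits.
// Illustrative stand-ins only; simplified from the pattern used in this commit.
#include <cstdio>

enum my_type { TYPE_F32, TYPE_F16 };            // stand-in for ggml_type

template <my_type T> struct mm_traits;          // primary template, specialized per type

template <> struct mm_traits<TYPE_F32> {
    using scalar_t = float;
    static constexpr const char * name = "f32";
};

template <> struct mm_traits<TYPE_F16> {
    using scalar_t = unsigned short;            // stand-in for half
    static constexpr const char * name = "f16";
};

template <my_type T>
static void mul_mat_impl() {
    using traits = mm_traits<T>;
    // everything type-dependent comes from the traits, so the body stays generic
    std::printf("running %s path, %zu-byte scalars\n",
                traits::name, sizeof(typename traits::scalar_t));
}

// One runtime switch turns the type tag into a fully typed implementation.
static void mul_mat(my_type t) {
    switch (t) {
        case TYPE_F32: mul_mat_impl<TYPE_F32>(); break;
        case TYPE_F16: mul_mat_impl<TYPE_F16>(); break;
    }
}

int main() {
    mul_mat(TYPE_F32);
    mul_mat(TYPE_F16);
}
```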

4 files changed: +162 −78 lines


ggml/src/ggml-cuda/convert.cu

Lines changed: 22 additions & 0 deletions
```diff
@@ -728,3 +728,25 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
             return nullptr;
     }
 }
+
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float, nv_bfloat16>;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half, nv_bfloat16>;
+        default:
+            return nullptr;
+    }
+}
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half, float>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16, float>;
+        default:
+            return nullptr;
+    }
+}
```
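
A hypothetical call site for the new getters (not part of this commit; `convert_src1_to_bf16` is an illustrative helper): the returned function pointer takes the source pointer, a destination buffer, the four logical dimensions, and per-dimension strides in elements, and enqueues the conversion on the given stream.

```cpp
// Hypothetical usage sketch (not from the commit): convert src1 to bf16 into a
// pool-allocated buffer before a cuBLAS call. Assumes the headers and helpers
// already available inside ggml-cuda.cu.
static void convert_src1_to_bf16(const ggml_tensor * src1,
                                 ggml_cuda_pool_alloc<nv_bfloat16> & buf,
                                 cudaStream_t stream) {
    const to_bf16_nc_cuda_t convert = ggml_get_to_bf16_nc_cuda(src1->type);
    GGML_ASSERT(convert != nullptr);              // F32 and F16 sources are supported

    buf.alloc(ggml_nelements(src1));

    const size_t ts = ggml_type_size(src1->type);
    // strides are passed in elements, not bytes
    convert(src1->data, buf.get(),
            src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
            src1->nb[1]/ts, src1->nb[2]/ts, src1->nb[3]/ts, stream);
}
```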

ggml/src/ggml-cuda/convert.cuh

Lines changed: 5 additions & 0 deletions
```diff
@@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
     int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
 
+typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
 typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
```

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 131 additions & 76 deletions
```diff
@@ -1749,7 +1749,7 @@ static void ggml_cuda_op_mul_mat(
 }
 
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const void * src0_as_f16, const void * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
@@ -1772,91 +1772,139 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
 }
 
-static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+// Type traits for mapping ggml types to CUDA/cuBLAS types
+template<ggml_type T>
+struct batched_mul_mat_traits;
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_F32> {
+    using cuda_type = float;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+    static inline const cudaDataType_t data_type = CUDA_R_32F;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
+    static inline const float alpha = 1.0f;
+    static inline const float beta = 0.0f;
+    static inline const void* get_alpha() { static const float val = alpha; return &val; }
+    static inline const void* get_beta() { static const float val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
+};
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_BF16> {
+    using cuda_type = nv_bfloat16;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+    static inline const cudaDataType_t data_type = CUDA_R_16BF;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
+    static inline const float alpha = 1.0f;
+    static inline const float beta = 0.0f;
+    static inline const void* get_alpha() { static const float val = alpha; return &val; }
+    static inline const void* get_beta() { static const float val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
+};
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_F16> {
+    using cuda_type = half;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+    static inline const cudaDataType_t data_type = CUDA_R_16F;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
+    static inline const half alpha = 1.0;
+    static inline const half beta = 0.0;
+    static inline const void* get_alpha() { static const half val = alpha; return &val; }
+    static inline const void* get_beta() { static const half val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
+};
+
+template<ggml_type src0_type>
+static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    using traits = batched_mul_mat_traits<src0_type>;
+    using cuda_t = typename traits::cuda_type;
+
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
-
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == src0_type);
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
-    GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t ne_dst = ggml_nelements(dst);
-
     cudaStream_t main_stream = ctx.stream();
-
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    const half * src0_f16 = (const half *) src0->data;
     float * dst_ddf = (float *) dst->data;
-
-    const half * src1_f16 = (const half *) src1->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
-    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+    const cuda_t * src0_ptr = nullptr;
+    const cuda_t * src1_ptr = nullptr;
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
+
+    // Handle src0
+    src0_ptr = (const cuda_t *) src0->data;
+
+    // Handle src1 - convert if necessary
+    if (src1->type == src0_type) {
+        src1_ptr = (const cuda_t *) src1->data;
+    } else {
+        // Convert src1 to target type using traits conversion functions
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_alloc.alloc(ne_src1);
 
-        src1_f16 = src1_f16_alloc.get();
+        const auto convert_func = traits::get_nc_converter(src1->type);
+        GGML_ASSERT(convert_func != nullptr);
+        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+        src1_ptr = src1_alloc.get();
         s11 = ne10;
         s12 = ne11*s11;
        s13 = ne12*s12;
     }
 
-    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    // Setup destination buffer
+    ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
     char * dst_t;
-
-    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-    cudaDataType_t cu_data_type = CUDA_R_16F;
-
-    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
-    const half alpha_f16 = 1.0f;
-    const half beta_f16 = 0.0f;
-
+    cublasComputeType_t cu_compute_type = traits::compute_type;
+    cudaDataType_t cu_data_type = traits::data_type;
+    cudaDataType_t cu_data_type_a = traits::data_type;
+    cudaDataType_t cu_data_type_b = traits::data_type;
+    const void * alpha = traits::get_alpha();
+    const void * beta = traits::get_beta();
     const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
-
-    const void * alpha = &alpha_f16;
-    const void * beta = &beta_f16;
+    const float beta_f32 = 0.0f;
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
-
-        nbd2 /= sizeof(float) / sizeof(half);
-        nbd3 /= sizeof(float) / sizeof(half);
+        if constexpr (src0_type == GGML_TYPE_F32) {
+            dst_t = (char *) dst_ddf; // Direct F32 output
+        } else {
+            dst_t = (char *) dst_temp.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(cuda_t);
+            nbd3 /= sizeof(float) / sizeof(cuda_t);
+        }
     } else {
         dst_t = (char *) dst_ddf;
-
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_32F;
-
+        cu_data_type = CUDA_R_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta = &beta_f32;
     }
 
     int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
     if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta = &beta_f32;
     }
 
     GGML_ASSERT(ne12 % ne02 == 0);
@@ -1866,35 +1914,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-#if 0
-    // use cublasGemmEx
-    {
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                CUBLAS_CHECK(
-                cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
-                    ne01, ne11, ne10,
-                    alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
-                           src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
-                    beta,  ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
-                    cu_compute_type,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-            }
-        }
-    }
-#else
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
-                       src1_f16, CUDA_R_16F, s11, s12,             // strideB
-                beta,  dst_t, cu_data_type, ne0, ne1*ne0,          // strideC
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
+                       src1_ptr, cu_data_type_b, s11, s12,             // strideB
+                beta,  dst_t, cu_data_type, ne0, ne1*ne0,              // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1905,34 +1933,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
 
+        size_t src1_stride_size = sizeof(cuda_t);
+
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
+                src0_ptr, src1_ptr, dst_t,
                 ptrs_src.get(), ptrs_dst.get(),
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
+                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
                 nbd2, nbd3,
                 r2, r3);
+
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
-#endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    // Convert output back to F32 if needed
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
+        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
+    }
+}
+
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
     }
 }
 
@@ -1984,6 +2033,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
+    //TODO update for generic tensor parallelism
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+    bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
+
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -1992,8 +2047,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
+        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
```
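
For context, a minimal standalone sketch (not from the commit) of the cuBLAS configuration the new BF16 path lands on when full F32 precision is requested for dst: BF16 A/B operands, F32 output, FP32 accumulation, and float alpha/beta because the compute type is 32F. `batched_bf16_gemm_f32_out` is a hypothetical helper; error handling is omitted. The F16 path keeps CUBLAS_COMPUTE_16F with half-typed scalars, which is why the traits expose get_alpha()/get_beta() instead of a single float constant.

```cpp
// Standalone sketch: BF16 x BF16 -> F32 strided-batched GEMM with FP32 accumulation.
// Column-major operands, one GEMM per batch entry; strides are in elements.
#include <cublas_v2.h>
#include <cuda_bf16.h>

void batched_bf16_gemm_f32_out(cublasHandle_t handle,
                               const __nv_bfloat16 * dA,  // m x k per batch
                               const __nv_bfloat16 * dB,  // k x n per batch
                               float * dC,                // m x n per batch
                               int m, int n, int k, int batch) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;  // float scalars: compute type is CUBLAS_COMPUTE_32F

    cublasGemmStridedBatchedEx(handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        m, n, k,
        &alpha,
        dA, CUDA_R_16BF, m, (long long)m * k,   // strideA
        dB, CUDA_R_16BF, k, (long long)k * n,   // strideB
        &beta,
        dC, CUDA_R_32F,  m, (long long)m * n,   // strideC
        batch,
        CUBLAS_COMPUTE_32F,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```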

tests/test-backend-ops.cpp

Lines changed: 4 additions & 2 deletions
```diff
@@ -4425,8 +4425,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (auto nr : {1,4}) {
         for (uint32_t m = 0; m < 2; ++m) {
             for (uint32_t k = 0; k < 2; ++k) {
-                test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
-                test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
+                for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
+                    test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
+                    test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
+                }
             }
         }
     }
```
