@@ -1842,16 +1842,27 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+    const cuda_t * src0_ptr = nullptr;
+    const cuda_t * src1_ptr = nullptr;
+
+    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+    // Handle src0
+    src0_ptr = (const cuda_t *) src0->data;
+
+    // Handle src1 - convert if necessary
+    if (src1->type == src0_type) {
+        src1_ptr = (const cuda_t *) src1->data;
+    } else {
+        // Convert src1 to target type using traits conversion functions
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_alloc.alloc(ne_src1);
 
-        src1_f16 = src1_f16_alloc.get();
+        const auto convert_func = traits::get_nc_converter(src1->type);
+        GGML_ASSERT(convert_func != nullptr);
+        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+        src1_ptr = src1_alloc.get();
 
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
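
The rewritten block above relies on `cuda_t`, `src0_type`, and a `traits` helper whose definitions are outside this diff. As a rough sketch only (the struct name `batched_mul_mat_traits` and its member names are assumptions, not taken from the PR), a per-type traits specialization along these lines would be enough to drive it; the F16 case can reuse the `to_fp16_nc_cuda_t` / `ggml_get_to_fp16_nc_cuda` helpers that the removed code called directly, with BF16 and F32 handled analogously.

```cpp
// Hypothetical sketch of the traits machinery assumed by the hunk above
// (names are illustrative; the PR's actual definitions may differ).
// Each specialization ties a ggml source type to:
//   * cuda_t              - the device-side element type handed to cuBLAS
//   * ggml_type_val       - the matching ggml type, used for the dst -> F32 conversion
//   * get_nc_converter()  - a non-contiguous converter for mismatched src1 types
template <ggml_type src0_type>
struct batched_mul_mat_traits;

template <>
struct batched_mul_mat_traits<GGML_TYPE_F16> {
    using cuda_t = half;
    static constexpr ggml_type ggml_type_val = GGML_TYPE_F16;
    static to_fp16_nc_cuda_t get_nc_converter(ggml_type src1_type) {
        return ggml_get_to_fp16_nc_cuda(src1_type); // same helper the old code called directly
    }
};
```

Inside the templated implementation, `traits` would then just be an alias such as `using traits = batched_mul_mat_traits<src0_type>;`, with `cuda_t` pulled from it.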
@@ -1948,11 +1959,29 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
-#endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    // Convert output back to F32 if needed
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
+        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
+    }
+}
+
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
     }
 }
 
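
The `<GGML_TYPE_F32>`, `<GGML_TYPE_BF16>`, and `<GGML_TYPE_F16>` instantiations above imply that the formerly F16-only implementation is now a function template over src0's type. A minimal sketch of the declaration this wrapper dispatches to, inferred from the call sites rather than copied from the PR:

```cpp
// Inferred from the dispatch switch above: a single implementation templated on
// src0's ggml type; the cuBLAS storage/compute types follow from the traits struct.
template <ggml_type src0_type>
static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx,
        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
```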
@@ -2004,6 +2033,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     // printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     // printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
+    // TODO update for generic tensor parallelism
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+    bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
+
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
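
This hunk only computes the three `use_batched_cublas_*` flags; presumably they replace the old F16-only condition further down in `ggml_cuda_mul_mat` when the batched cuBLAS path is chosen. A hedged sketch of how such a guard could look (the exact condition in the PR may differ):

```cpp
// Hypothetical consumer of the new flags (illustrative only): any of the three
// supported src0 types may take the batched cuBLAS path, provided neither operand
// is transposed and src1 actually has a batch dimension.
const bool use_batched_cublas =
    (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
    && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)
    && src1->ne[2]*src1->ne[3] > 1;

if (!split && use_batched_cublas) {
    ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
}
```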