Skip to content

Commit 1e44f3b

Browse files
committed
Address review comments
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
1 parent d918041 commit 1e44f3b

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,14 +1202,12 @@ static void ggml_cuda_op_mul_mat_cublas(
12021202

12031203
const int cc = ggml_cuda_info().devices[id].cc;
12041204

1205-
const bool support_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
1205+
const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
12061206
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
12071207

1208-
const bool support_fp16 = (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
1209-
GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
12101208
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
12111209

1212-
if (support_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
1210+
if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
12131211
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
12141212
if (src1->type != GGML_TYPE_BF16) {
12151213
const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1237,7 +1235,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12371235

12381236
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
12391237
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
1240-
} else if (support_fp16 && use_fp16) {
1238+
} else if (fast_fp16_hardware_available(cc) && use_fp16) {
12411239
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
12421240
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
12431241
if (src0->type != GGML_TYPE_F16) {

0 commit comments

Comments
 (0)