Skip to content

Commit 8110881

Browse files
committed
Handle FP16 input and FP16 kernels in OpenCL conv2d; further optimizations
1 parent edc421c commit 8110881

File tree

4 files changed

+351
-137
lines changed

4 files changed

+351
-137
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ set(GGML_OPENCL_KERNELS
103103
pad
104104
repeat
105105
conv2d
106+
conv2d_f16_f32
106107
)
107108

108109
foreach (K ${GGML_OPENCL_KERNELS})

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 80 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,9 @@ struct ggml_backend_opencl_context {
387387
cl_program program_tanh;
388388
cl_program program_upscale;
389389
cl_program program_concat;
390-
cl_program program_conv_2d;
390+
cl_program program_conv_2d_f16;
391+
cl_program program_conv_2d_f32;
392+
cl_program program_conv_2d_f16_f32;
391393
cl_program program_tsembd;
392394
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
393395

@@ -434,7 +436,9 @@ struct ggml_backend_opencl_context {
434436
cl_kernel kernel_upscale_bilinear;
435437
cl_kernel kernel_concat_f32_contiguous;
436438
cl_kernel kernel_concat_f32_non_contiguous;
437-
cl_kernel kernel_conv_2d;
439+
cl_kernel kernel_conv_2d_f16;
440+
cl_kernel kernel_conv_2d_f32;
441+
cl_kernel kernel_conv_2d_f16_f32;
438442
cl_kernel kernel_timestep_embedding;
439443
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
440444

@@ -1402,25 +1406,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
14021406
}
14031407
}
14041408

1405-
// conv2d
1406-
{
1407-
#ifdef GGML_OPENCL_EMBED_KERNELS
1408-
const std::string kernel_src {
1409-
#include "conv2d.cl.h"
1410-
};
1411-
#else
1412-
const std::string kernel_src = read_file("conv2d.cl");
1413-
#endif
1414-
if (!kernel_src.empty()) {
1415-
backend_ctx->program_conv_2d =
1416-
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1417-
CL_CHECK((backend_ctx->kernel_conv_2d = clCreateKernel(backend_ctx->program_conv_2d, "kernel_conv_2d", &err), err));
1418-
GGML_LOG_CONT(".");
1419-
} else {
1420-
GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
1421-
backend_ctx->program_conv_2d = nullptr;
1422-
backend_ctx->kernel_conv_2d = nullptr;
1423-
}
1409+
// conv2d
1410+
{
1411+
#ifdef GGML_OPENCL_EMBED_KERNELS
1412+
const std::string kernel_src {
1413+
#include "conv2d.cl.h"
1414+
};
1415+
const std::string kernel_src_f16_f32 {
1416+
#include "conv2d_f16_f32.cl.h"
1417+
};
1418+
#else
1419+
const std::string kernel_src = read_file("conv2d.cl");
1420+
const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
1421+
#endif
1422+
if (!kernel_src.empty()) {
1423+
backend_ctx->program_conv_2d_f16 =
1424+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
1425+
CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
1426+
GGML_LOG_CONT(".");
1427+
backend_ctx->program_conv_2d_f32 =
1428+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1429+
CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
1430+
GGML_LOG_CONT(".");
1431+
} else {
1432+
GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
1433+
backend_ctx->program_conv_2d_f16 = nullptr;
1434+
backend_ctx->kernel_conv_2d_f16 = nullptr;
1435+
backend_ctx->program_conv_2d_f32 = nullptr;
1436+
backend_ctx->kernel_conv_2d_f32 = nullptr;
1437+
}
1438+
if (!kernel_src_f16_f32.empty()) {
1439+
backend_ctx->program_conv_2d_f16_f32 =
1440+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
1441+
CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
1442+
GGML_LOG_CONT(".");
1443+
} else {
1444+
GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
1445+
backend_ctx->program_conv_2d_f16_f32 = nullptr;
1446+
backend_ctx->kernel_conv_2d_f16_f32 = nullptr;
1447+
}
14241448
}
14251449

14261450
// mul_mv_id_q4_0_f32_8x_flat
@@ -2279,7 +2303,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
22792303
case GGML_OP_UPSCALE:
22802304
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
22812305
case GGML_OP_CONV_2D:
2282-
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2306+
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
2307+
(op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
2308+
(op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
22832309
case GGML_OP_CONCAT:
22842310
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
22852311
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -4729,25 +4755,44 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co
47294755
const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
47304756
const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];
47314757

4732-
const cl_uint cl_nb01 = nb01/nb00; const cl_uint cl_nb02 = nb02/nb00; const cl_uint cl_nb03 = nb03/nb00;
4733-
const cl_uint cl_nb11 = nb11/nb10; const cl_uint cl_nb12 = nb12/nb10; const cl_uint cl_nb13 = nb13/nb10;
4734-
const cl_uint cl_nb1 = nb1/nb0; const cl_uint cl_nb2 = nb2/nb0; const cl_uint cl_nb3 = nb3/nb0;
4758+
const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
4759+
const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
4760+
const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);
47354761

47364762
const int64_t NPQ = (int64_t)N * OW * OH;
47374763

4738-
const uint32_t WG_SIZE = 128;
4739-
const uint32_t BS_K = 128;
4740-
const uint32_t BS_CRS = 16;
4764+
const uint32_t BS_K = 64;
47414765
const uint32_t BS_NPQ = 64;
4766+
const uint32_t BS_CRS = 16;
47424767
const uint32_t VEC_SIZE = 4;
47434768

4769+
const uint32_t TS_K = 4;
4770+
const uint32_t TS_NPQ = 8;
4771+
4772+
const uint32_t WG_K = BS_K / TS_K;
4773+
const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;
4774+
47444775
auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; };
47454776
const uint32_t NB_K = splitWork(Cout, BS_K);
47464777
const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);
47474778

4748-
const size_t shmem_size = (size_t)(BS_K * (BS_CRS + 1) * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE + 1) * sizeof(cl_half4));
4779+
cl_kernel kernel;
4780+
size_t shmem_size;
4781+
4782+
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
4783+
kernel = backend_ctx->kernel_conv_2d_f16;
4784+
shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
4785+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
4786+
kernel = backend_ctx->kernel_conv_2d_f32;
4787+
shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
4788+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
4789+
kernel = backend_ctx->kernel_conv_2d_f16_f32;
4790+
shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
4791+
} else {
4792+
GGML_ASSERT(false && "Unsupported data type combination for conv2d");
4793+
return;
4794+
}
47494795

4750-
cl_kernel kernel = backend_ctx->kernel_conv_2d;
47514796
cl_uint idx = 0;
47524797
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
47534798
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
@@ -4762,18 +4807,18 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co
47624807
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
47634808
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));
47644809

4765-
size_t global_work_size[] = { (size_t)NB_K * WG_SIZE, (size_t)NB_NPQ, 1 };
4766-
size_t local_work_size[] = { (size_t)WG_SIZE, 1, 1 };
4810+
size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
4811+
size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };
47674812

47684813
#ifdef GGML_OPENCL_PROFILING
47694814
cl_event evt;
4770-
CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
4815+
CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 2, NULL, global_work_size, local_work_size, 0, NULL, &evt));
47714816

47724817
backend_ctx->profiling_info.emplace_back();
4773-
populateProfilingInfo(backend_ctx->profiling_info.back(), evt, kernel, 3, global_work_size, local_work_size, dst);
4818+
populateProfilingInfo(backend_ctx->profiling_info.back(), evt, kernel, 2, global_work_size, local_work_size, dst);
47744819
#else
47754820
GGML_UNUSED(dst);
4776-
CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
4821+
CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 2, NULL, global_work_size, local_work_size, 0, NULL, NULL));
47774822
#endif
47784823
}
47794824

0 commit comments

Comments (0)