From b898521ee7d4bd5321821e04b48b3cb0bb312a96 Mon Sep 17 00:00:00 2001 From: rmatif Date: Thu, 26 Jun 2025 18:54:41 +0000 Subject: [PATCH 1/5] add conv2d kernel --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 92 ++++++++++++ ggml/src/ggml-opencl/kernels/conv2d.cl | 193 +++++++++++++++++++++++++ 3 files changed, 286 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/conv2d.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index ec5d8cf59556b..b51607dd3175a 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -105,6 +105,7 @@ set(GGML_OPENCL_KERNELS pad repeat mul_mat_f16_f32 + conv2d ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 3388259152b46..0d01a5594a152 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -390,6 +390,7 @@ struct ggml_backend_opencl_context { cl_program program_tanh; cl_program program_upscale; cl_program program_concat; + cl_program program_conv_2d; cl_program program_tsembd; cl_program program_mul_mv_id_q4_0_f32_8x_flat; @@ -441,6 +442,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_upscale_bilinear; cl_kernel kernel_concat_f32_contiguous; cl_kernel kernel_concat_f32_non_contiguous; + cl_kernel kernel_conv_2d; cl_kernel kernel_timestep_embedding; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; @@ -1478,6 +1480,27 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // conv2d + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "conv2d.cl.h" + }; +#else + const std::string kernel_src = read_file("conv2d.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_conv_2d = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + 
CL_CHECK((backend_ctx->kernel_conv_2d = clCreateKernel(backend_ctx->program_conv_2d, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n"); + backend_ctx->program_conv_2d = nullptr; + backend_ctx->kernel_conv_2d = nullptr; + } + } + // mul_mv_id_q4_0_f32_8x_flat { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2361,6 +2384,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te op->src[0]->ne[3] == 1 && op->ne[3] == 1; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_CONV_2D: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_CONCAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_TIMESTEP_EMBEDDING: @@ -4946,7 +4971,12 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); } +<<<<<<< HEAD static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +======= +static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_TENSOR_BINARY_OP_LOCALS; +>>>>>>> 4d5d5a83 (add conv2d kernel) ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; @@ -4957,6 +4987,7 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; +<<<<<<< HEAD const int M = src0->ne[1]; const int N = src1->ne[1]; const int K = src0->ne[0]; @@ -4996,6 +5027,61 @@ static void 
ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten }; backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); +======= + const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13; + const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1; + + const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1]; + const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3]; + const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5]; + + const cl_uint cl_nb01 = nb01/nb00; const cl_uint cl_nb02 = nb02/nb00; const cl_uint cl_nb03 = nb03/nb00; + const cl_uint cl_nb11 = nb11/nb10; const cl_uint cl_nb12 = nb12/nb10; const cl_uint cl_nb13 = nb13/nb10; + const cl_uint cl_nb1 = nb1/nb0; const cl_uint cl_nb2 = nb2/nb0; const cl_uint cl_nb3 = nb3/nb0; + + const int64_t NPQ = (int64_t)N * OW * OH; + + const uint32_t WG_SIZE = 128; + const uint32_t BS_K = 128; + const uint32_t BS_CRS = 16; + const uint32_t BS_NPQ = 64; + const uint32_t VEC_SIZE = 4; + + auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; }; + const uint32_t NB_K = splitWork(Cout, BS_K); + const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ); + + const size_t shmem_size = (size_t)(BS_K * (BS_CRS + 1) * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE + 1) * sizeof(cl_half4)); + + cl_kernel kernel = backend_ctx->kernel_conv_2d; + cl_uint idx = 0; + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, 
sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3)); + + size_t global_work_size[] = { (size_t)NB_K * WG_SIZE, (size_t)NB_NPQ, 1 }; + size_t local_work_size[] = { (size_t)WG_SIZE, 1, 1 }; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + 
backend_ctx->profiling_info.emplace_back(); + populateProfilingInfo(backend_ctx->profiling_info.back(), evt, kernel, 3, global_work_size, local_work_size, dst); +#else + GGML_UNUSED(dst); + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +>>>>>>> 4d5d5a83 (add conv2d kernel) } static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -6752,6 +6838,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } ggml_cl_upscale(backend, tensor->src[0], tensor); return true; + case GGML_OP_CONV_2D: + if (!any_on_device) { + return false; + } + func = ggml_cl_conv_2d; + break; case GGML_OP_CONCAT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/conv2d.cl b/ggml/src/ggml-opencl/kernels/conv2d.cl new file mode 100644 index 0000000000000..8bebc44bf2124 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/conv2d.cl @@ -0,0 +1,193 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +#define REQD_SUBGROUP_SIZE_64 +#endif + +#define T_FLOAT half +#define T_FLOAT4 half4 +#define T_ACCUM float4 +#define VEC_SIZE 4 + +#define BS_K 128 +#define BS_CRS 16 +#define BS_NPQ 64 +#define TS_K 8 +#define TS_NPQ 8 +#define WG_SIZE 128 + +#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE) +#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE) + 
+#define NT_K (BS_K / TS_K) +#define NT_NPQ (BS_NPQ / TS_NPQ) + +static inline uint splitWork(uint work_size, uint block_size){ + return (work_size + block_size - 1) / block_size; +} + +REQD_SUBGROUP_SIZE_64 +kernel void kernel_conv_2d( + global void* p_knl, + ulong off_knl, + global void* p_src, + ulong off_src, + global void* p_dst, + ulong off_dst, + local void* shared, + uint Cout, uint Cin, uint N, + uint KW, uint KH, uint W, uint H, uint OW, uint OH, + uint s0, uint s1, uint p0, uint p1, uint d0, uint d1, + uint nb01, uint nb02, uint nb03, + uint nb11, uint nb12, uint nb13, + uint nb1, uint nb2, uint nb3 +) { + global float* knl_data = (global float*) ((global char*)p_knl + off_knl); + global float* src_data = (global float*) ((global char*)p_src + off_src); + global float* dst_data = (global float*) ((global char*)p_dst + off_dst); + + const uint tid = get_local_id(0); + + const uint K = Cout; + const uint CRS = Cin*KH*KW; + const uint NPQ = N*OH*OW; + + const uint NB_CRS = splitWork(CRS, BS_CRS); + + const uint Ash_stride = BS_CRS + 1; + const uint Bsh_stride_vec = BS_NPQ_VEC + 1; + + local T_FLOAT* Ash = (local T_FLOAT*)shared; + local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * Ash_stride]; + + T_ACCUM regC[TS_K][TS_NPQ_VEC]; + for (int i = 0; i < TS_K; ++i) { + for (int j = 0; j < TS_NPQ_VEC; ++j) { + regC[i][j] = (T_ACCUM)(0.0f); + } + } + + const uint B_idx_K = get_group_id(0); + const uint B_idx_NPQ = get_group_id(1); + + const uint T_y = tid / NT_NPQ; + const uint T_x = tid % NT_NPQ; + + for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) { + for(uint i = tid; i < BS_K * BS_CRS; i += WG_SIZE){ + uint k_l = i / BS_CRS; + uint crs_l = i % BS_CRS; + uint k_g = B_idx_K*BS_K + k_l; + uint crs_g = B_idx_CRS*BS_CRS + crs_l; + if(k_g < K && crs_g < CRS){ + uint Cin_idx = crs_g / (KW*KH); + uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW; + uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW; + uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03; + 
Ash[k_l * Ash_stride + crs_l] = (T_FLOAT)knl_data[knl_idx]; + } else { + Ash[k_l * Ash_stride + crs_l] = (T_FLOAT)0.0h; + } + } + + for (uint i = tid; i < BS_CRS * BS_NPQ_VEC; i += WG_SIZE) { + uint crs_l = i / BS_NPQ_VEC; + uint npq_l_vec = i % BS_NPQ_VEC; + + float4 val_f = (float4)(0.0f); + uint crs_g = B_idx_CRS * BS_CRS + crs_l; + + if (crs_g < CRS) { + uint Cin_idx = crs_g / (KW * KH); + uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW; + uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW; + + for (int v = 0; v < VEC_SIZE; ++v) { + uint npq_g = B_idx_NPQ * BS_NPQ + npq_l_vec * VEC_SIZE + v; + if (npq_g < NPQ) { + uint N_idx = npq_g / (OH * OW); + uint pq_idx = npq_g % (OH * OW); + uint OH_idx = pq_idx / OW; + uint OW_idx = pq_idx % OW; + int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1); + int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0); + + if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) { + uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13; + switch (v) { + case 0: val_f.s0 = src_data[src_idx]; break; + case 1: val_f.s1 = src_data[src_idx]; break; + case 2: val_f.s2 = src_data[src_idx]; break; + case 3: val_f.s3 = src_data[src_idx]; break; + } + } + } + } + } + Bsh[crs_l * Bsh_stride_vec + npq_l_vec] = convert_half4(val_f); + } + barrier(CLK_LOCAL_MEM_FENCE); + + for(uint crs_l = 0; crs_l < BS_CRS; ++crs_l){ + T_FLOAT regA[TS_K]; + for(uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg){ + regA[k_l_reg] = Ash[(T_y*TS_K + k_l_reg)*Ash_stride + crs_l]; + } + for(uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg){ + T_FLOAT4 regB = Bsh[crs_l*Bsh_stride_vec + T_x*TS_NPQ_VEC + npq_l_vec_reg]; + for(uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg){ + regC[k_l_reg][npq_l_vec_reg] = mad((float)regA[k_l_reg], convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + for(uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg){ + uint k_g = B_idx_K * BS_K + T_y * TS_K + k_l_reg; + if(k_g >= K) 
continue; + + for(uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg){ + uint npq_g_base = B_idx_NPQ * BS_NPQ + (T_x * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE; + uint N_idx = npq_g_base / (OH*OW); + uint pq_idx = npq_g_base % (OH*OW); + uint OH_idx = pq_idx / OW; + uint OW_idx = pq_idx % OW; + + if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) { + uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3; + vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]); + } else { + T_ACCUM res = regC[k_l_reg][npq_l_vec_reg]; + for (int v = 0; v < VEC_SIZE; ++v) { + uint npq_g = npq_g_base + v; + if (npq_g < NPQ) { + uint N_idx_s = npq_g / (OH*OW); + uint pq_idx_s = npq_g % (OH*OW); + uint OH_idx_s = pq_idx_s / OW; + uint OW_idx_s = pq_idx_s % OW; + uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3; + float val_f; + switch(v) { + case 0: val_f = res.s0; break; + case 1: val_f = res.s1; break; + case 2: val_f = res.s2; break; + default:val_f = res.s3; break; + } + dst_data[dst_idx_s] = val_f; + } + } + } + } + } +} From bc3cd91e12a2c206f037c447a167e076c8a15e73 Mon Sep 17 00:00:00 2001 From: rmatif Date: Thu, 26 Jun 2025 19:20:27 +0000 Subject: [PATCH 2/5] fix trailing whitespace --- ggml/src/ggml-opencl/ggml-opencl.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0d01a5594a152..763d9445ad3e0 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -6926,3 +6926,4 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor func(backend, tensor->src[0], tensor->src[1], tensor); return true; } + From f555aa3076a56e059e3072eb8e353b732173c5d5 Mon Sep 17 00:00:00 2001 From: rmatif Date: Thu, 26 Jun 2025 19:22:31 +0000 Subject: [PATCH 3/5] whitespace fix --- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git
a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 763d9445ad3e0..8b13bea77e7d7 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5050,7 +5050,7 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; }; const uint32_t NB_K = splitWork(Cout, BS_K); const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ); - + const size_t shmem_size = (size_t)(BS_K * (BS_CRS + 1) * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE + 1) * sizeof(cl_half4)); cl_kernel kernel = backend_ctx->kernel_conv_2d; From 8412441e0612ca5663fe3abd73a09eca62224dd4 Mon Sep 17 00:00:00 2001 From: rmatif Date: Wed, 16 Jul 2025 18:10:48 +0000 Subject: [PATCH 4/5] handle f16 input and f16 kernel, more opt --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 115 ++++++---- ggml/src/ggml-opencl/kernels/conv2d.cl | 196 +++++++++--------- .../src/ggml-opencl/kernels/conv2d_f16_f32.cl | 176 ++++++++++++++++ 4 files changed, 351 insertions(+), 137 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index b51607dd3175a..015fa8f06824e 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -106,6 +106,7 @@ set(GGML_OPENCL_KERNELS repeat mul_mat_f16_f32 conv2d + conv2d_f16_f32 ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 8b13bea77e7d7..46d38ff913b91 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -390,7 +390,9 @@ struct ggml_backend_opencl_context { cl_program program_tanh; cl_program program_upscale; cl_program program_concat; - cl_program program_conv_2d; + cl_program program_conv_2d_f16; + 
cl_program program_conv_2d_f32; + cl_program program_conv_2d_f16_f32; cl_program program_tsembd; cl_program program_mul_mv_id_q4_0_f32_8x_flat; @@ -442,7 +444,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_upscale_bilinear; cl_kernel kernel_concat_f32_contiguous; cl_kernel kernel_concat_f32_non_contiguous; - cl_kernel kernel_conv_2d; + cl_kernel kernel_conv_2d_f16; + cl_kernel kernel_conv_2d_f32; + cl_kernel kernel_conv_2d_f16_f32; cl_kernel kernel_timestep_embedding; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; @@ -1480,25 +1484,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } - // conv2d - { -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src { - #include "conv2d.cl.h" - }; -#else - const std::string kernel_src = read_file("conv2d.cl"); -#endif - if (!kernel_src.empty()) { - backend_ctx->program_conv_2d = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_conv_2d = clCreateKernel(backend_ctx->program_conv_2d, "kernel_conv_2d", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. 
This op will not be available.\n"); - backend_ctx->program_conv_2d = nullptr; - backend_ctx->kernel_conv_2d = nullptr; - } + // conv2d + { + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "conv2d.cl.h" + }; + const std::string kernel_src_f16_f32 { + #include "conv2d_f16_f32.cl.h" + }; + #else + const std::string kernel_src = read_file("conv2d.cl"); + const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl"); + #endif + if (!kernel_src.empty()) { + backend_ctx->program_conv_2d_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str()); + CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + backend_ctx->program_conv_2d_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n"); + backend_ctx->program_conv_2d_f16 = nullptr; + backend_ctx->kernel_conv_2d_f16 = nullptr; + backend_ctx->program_conv_2d_f32 = nullptr; + backend_ctx->kernel_conv_2d_f32 = nullptr; + } + if (!kernel_src_f16_f32.empty()) { + backend_ctx->program_conv_2d_f16_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. 
This op will not be available.\n"); + backend_ctx->program_conv_2d_f16_f32 = nullptr; + backend_ctx->kernel_conv_2d_f16_f32 = nullptr; + } } // mul_mv_id_q4_0_f32_8x_flat @@ -2385,7 +2409,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: - return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || + (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || + (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); case GGML_OP_CONCAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_TIMESTEP_EMBEDDING: @@ -5035,25 +5061,44 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3]; const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5]; - const cl_uint cl_nb01 = nb01/nb00; const cl_uint cl_nb02 = nb02/nb00; const cl_uint cl_nb03 = nb03/nb00; - const cl_uint cl_nb11 = nb11/nb10; const cl_uint cl_nb12 = nb12/nb10; const cl_uint cl_nb13 = nb13/nb10; - const cl_uint cl_nb1 = nb1/nb0; const cl_uint cl_nb2 = nb2/nb0; const cl_uint cl_nb3 = nb3/nb0; + const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type); + const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type); + const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = 
nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type); const int64_t NPQ = (int64_t)N * OW * OH; - const uint32_t WG_SIZE = 128; - const uint32_t BS_K = 128; - const uint32_t BS_CRS = 16; + const uint32_t BS_K = 64; const uint32_t BS_NPQ = 64; + const uint32_t BS_CRS = 16; const uint32_t VEC_SIZE = 4; + const uint32_t TS_K = 4; + const uint32_t TS_NPQ = 8; + + const uint32_t WG_K = BS_K / TS_K; + const uint32_t WG_NPQ = BS_NPQ / TS_NPQ; + auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; }; const uint32_t NB_K = splitWork(Cout, BS_K); const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ); - const size_t shmem_size = (size_t)(BS_K * (BS_CRS + 1) * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE + 1) * sizeof(cl_half4)); + cl_kernel kernel; + size_t shmem_size; + + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_conv_2d_f16; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4)); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_conv_2d_f32; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_conv_2d_f16_f32; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); + } else { + GGML_ASSERT(false && "Unsupported data type combination for conv2d"); + return; + } - cl_kernel kernel = backend_ctx->kernel_conv_2d; cl_uint idx = 0; CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1)); @@ 
-5068,18 +5113,18 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3)); - size_t global_work_size[] = { (size_t)NB_K * WG_SIZE, (size_t)NB_NPQ, 1 }; - size_t local_work_size[] = { (size_t)WG_SIZE, 1, 1 }; + size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 }; + size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 }; #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 2, NULL, global_work_size, local_work_size, 0, NULL, &evt)); backend_ctx->profiling_info.emplace_back(); - populateProfilingInfo(backend_ctx->profiling_info.back(), evt, kernel, 3, global_work_size, local_work_size, dst); + populateProfilingInfo(backend_ctx->profiling_info.back(), evt, kernel, 2, global_work_size, local_work_size, dst); #else GGML_UNUSED(dst); - CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 2, NULL, global_work_size, local_work_size, 0, NULL, NULL)); #endif >>>>>>> 4d5d5a83 (add conv2d kernel) } diff --git a/ggml/src/ggml-opencl/kernels/conv2d.cl b/ggml/src/ggml-opencl/kernels/conv2d.cl index 8bebc44bf2124..e339c90cff59f 100644 --- a/ggml/src/ggml-opencl/kernels/conv2d.cl +++ b/ggml/src/ggml-opencl/kernels/conv2d.cl @@ -1,42 +1,42 @@ +#ifdef USE_FP16 #pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define T_FLOAT half 
+#define T_FLOAT4 half4 +#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p) +#else +#define T_FLOAT float +#define T_FLOAT4 float4 +#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p) +#endif -#ifdef cl_intel_required_subgroup_size -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) +#if defined(cl_qcom_reqd_sub_group_size) #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) #else -#define REQD_SUBGROUP_SIZE_64 +#define REQD_SUBGROUP_SIZE_128 #endif -#define T_FLOAT half -#define T_FLOAT4 half4 #define T_ACCUM float4 #define VEC_SIZE 4 -#define BS_K 128 -#define BS_CRS 16 +#define BS_K 64 #define BS_NPQ 64 -#define TS_K 8 +#define BS_CRS 16 + +#define TS_K 4 #define TS_NPQ 8 -#define WG_SIZE 128 + +#define WG_K (BS_K / TS_K) +#define WG_NPQ (BS_NPQ / TS_NPQ) #define BS_NPQ_VEC (BS_NPQ / VEC_SIZE) #define TS_NPQ_VEC (TS_NPQ / VEC_SIZE) -#define NT_K (BS_K / TS_K) -#define NT_NPQ (BS_NPQ / TS_NPQ) - static inline uint splitWork(uint work_size, uint block_size){ return (work_size + block_size - 1) / block_size; } -REQD_SUBGROUP_SIZE_64 +REQD_SUBGROUP_SIZE_128 kernel void kernel_conv_2d( global void* p_knl, ulong off_knl, @@ -52,23 +52,26 @@ kernel void kernel_conv_2d( uint nb11, uint nb12, uint nb13, uint nb1, uint nb2, uint nb3 ) { - global float* knl_data = (global float*) ((global char*)p_knl + off_knl); - global float* src_data = (global float*) ((global char*)p_src + off_src); - global float* dst_data = (global float*) ((global char*)p_dst + off_dst); - - const uint tid = get_local_id(0); + global 
T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl); + global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src); + global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst); const uint K = Cout; const uint CRS = Cin*KH*KW; const uint NPQ = N*OH*OW; - const uint NB_CRS = splitWork(CRS, BS_CRS); + const uint lid_k = get_local_id(0); + const uint lid_npq = get_local_id(1); + const uint tid = lid_npq * WG_K + lid_k; - const uint Ash_stride = BS_CRS + 1; - const uint Bsh_stride_vec = BS_NPQ_VEC + 1; + const uint B_idx_K = get_group_id(0); + const uint B_idx_NPQ = get_group_id(1); + + const uint offset_k = B_idx_K * BS_K; + const uint offset_npq = B_idx_NPQ * BS_NPQ; local T_FLOAT* Ash = (local T_FLOAT*)shared; - local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * Ash_stride]; + local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS]; T_ACCUM regC[TS_K][TS_NPQ_VEC]; for (int i = 0; i < TS_K; ++i) { @@ -77,114 +80,103 @@ kernel void kernel_conv_2d( } } - const uint B_idx_K = get_group_id(0); - const uint B_idx_NPQ = get_group_id(1); - - const uint T_y = tid / NT_NPQ; - const uint T_x = tid % NT_NPQ; + const uint NB_CRS = splitWork(CRS, BS_CRS); for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) { - for(uint i = tid; i < BS_K * BS_CRS; i += WG_SIZE){ - uint k_l = i / BS_CRS; - uint crs_l = i % BS_CRS; - uint k_g = B_idx_K*BS_K + k_l; - uint crs_g = B_idx_CRS*BS_CRS + crs_l; - if(k_g < K && crs_g < CRS){ - uint Cin_idx = crs_g / (KW*KH); - uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW; - uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW; - uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03; - Ash[k_l * Ash_stride + crs_l] = (T_FLOAT)knl_data[knl_idx]; + const uint offset_crs = B_idx_CRS * BS_CRS; + + for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) { + const uint k_l = i / BS_CRS; + const uint crs_l = i % BS_CRS; + const uint k_g = offset_k + k_l; + const uint crs_g = offset_crs + crs_l; + + 
if (k_g < K && crs_g < CRS) { + const uint Cin_idx = crs_g / (KW*KH); + const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW; + const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW; + const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03; + Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx]; } else { - Ash[k_l * Ash_stride + crs_l] = (T_FLOAT)0.0h; + Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f; } } - for (uint i = tid; i < BS_CRS * BS_NPQ_VEC; i += WG_SIZE) { - uint crs_l = i / BS_NPQ_VEC; - uint npq_l_vec = i % BS_NPQ_VEC; - - float4 val_f = (float4)(0.0f); - uint crs_g = B_idx_CRS * BS_CRS + crs_l; + for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) { + const uint crs_l = i / BS_NPQ_VEC; + const uint npq_l_vec = i % BS_NPQ_VEC; + const uint crs_g = offset_crs + crs_l; + T_FLOAT4 val = (T_FLOAT4)(0.0f); if (crs_g < CRS) { - uint Cin_idx = crs_g / (KW * KH); - uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW; - uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW; - + const uint Cin_idx = crs_g / (KW * KH); + const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW; + const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW; for (int v = 0; v < VEC_SIZE; ++v) { - uint npq_g = B_idx_NPQ * BS_NPQ + npq_l_vec * VEC_SIZE + v; + const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v; if (npq_g < NPQ) { - uint N_idx = npq_g / (OH * OW); - uint pq_idx = npq_g % (OH * OW); - uint OH_idx = pq_idx / OW; - uint OW_idx = pq_idx % OW; - int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1); - int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0); + const uint N_idx = npq_g / (OH * OW); + const uint pq_idx = npq_g % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1); + const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0); if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) { - uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13; - switch (v) { - case 0: 
val_f.s0 = src_data[src_idx]; break; - case 1: val_f.s1 = src_data[src_idx]; break; - case 2: val_f.s2 = src_data[src_idx]; break; - case 3: val_f.s3 = src_data[src_idx]; break; - } + const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13; + ((T_FLOAT*)&val)[v] = src_data[src_idx]; } } } } - Bsh[crs_l * Bsh_stride_vec + npq_l_vec] = convert_half4(val_f); + Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val; } + barrier(CLK_LOCAL_MEM_FENCE); - for(uint crs_l = 0; crs_l < BS_CRS; ++crs_l){ + #pragma unroll + for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) { T_FLOAT regA[TS_K]; - for(uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg){ - regA[k_l_reg] = Ash[(T_y*TS_K + k_l_reg)*Ash_stride + crs_l]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l]; } - for(uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg){ - T_FLOAT4 regB = Bsh[crs_l*Bsh_stride_vec + T_x*TS_NPQ_VEC + npq_l_vec_reg]; - for(uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg){ - regC[k_l_reg][npq_l_vec_reg] = mad((float)regA[k_l_reg], convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]); + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]); } } } barrier(CLK_LOCAL_MEM_FENCE); } - for(uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg){ - uint k_g = B_idx_K * BS_K + T_y * TS_K + k_l_reg; - if(k_g >= K) continue; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + const uint k_g = offset_k + lid_k * TS_K + k_l_reg; + if (k_g >= K) continue; + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE; - for(uint npq_l_vec_reg = 0; npq_l_vec_reg < 
TS_NPQ_VEC; ++npq_l_vec_reg){ - uint npq_g_base = B_idx_NPQ * BS_NPQ + (T_x * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE; - uint N_idx = npq_g_base / (OH*OW); - uint pq_idx = npq_g_base % (OH*OW); - uint OH_idx = pq_idx / OW; - uint OW_idx = pq_idx % OW; + const uint N_idx = npq_g_base / (OH * OW); + const uint pq_idx = npq_g_base % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) { - uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3; - vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]); + const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3; + VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]); } else { T_ACCUM res = regC[k_l_reg][npq_l_vec_reg]; for (int v = 0; v < VEC_SIZE; ++v) { - uint npq_g = npq_g_base + v; + const uint npq_g = npq_g_base + v; if (npq_g < NPQ) { - uint N_idx_s = npq_g / (OH*OW); - uint pq_idx_s = npq_g % (OH*OW); - uint OH_idx_s = pq_idx_s / OW; - uint OW_idx_s = pq_idx_s % OW; - uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3; - float val_f; - switch(v) { - case 0: val_f = res.s0; break; - case 1: val_f = res.s1; break; - case 2: val_f = res.s2; break; - default:val_f = res.s3; break; - } - dst_data[dst_idx_s] = val_f; + const uint N_idx_s = npq_g / (OH*OW); + const uint pq_idx_s = npq_g % (OH*OW); + const uint OH_idx_s = pq_idx_s / OW; + const uint OW_idx_s = pq_idx_s % OW; + const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3; + dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]); } } } diff --git a/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl b/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl new file mode 100644 index 0000000000000..cb05637f33ac8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl @@ -0,0 +1,176 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#if defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION 
cl_qcom_reqd_sub_group_size : enable +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +#define REQD_SUBGROUP_SIZE_128 +#endif + +#define T_ACCUM float4 +#define VEC_SIZE 4 + +#define BS_K 64 +#define BS_NPQ 64 +#define BS_CRS 16 + +#define TS_K 4 +#define TS_NPQ 8 + +#define WG_K (BS_K / TS_K) +#define WG_NPQ (BS_NPQ / TS_NPQ) + +#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE) +#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE) + +static inline uint splitWork(uint work_size, uint block_size){ + return (work_size + block_size - 1) / block_size; +} + +REQD_SUBGROUP_SIZE_128 +kernel void kernel_conv_2d( + global void* p_knl, + ulong off_knl, + global void* p_src, + ulong off_src, + global void* p_dst, + ulong off_dst, + local void* shared, + uint Cout, uint Cin, uint N, + uint KW, uint KH, uint W, uint H, uint OW, uint OH, + uint s0, uint s1, uint p0, uint p1, uint d0, uint d1, + uint nb01, uint nb02, uint nb03, + uint nb11, uint nb12, uint nb13, + uint nb1, uint nb2, uint nb3 +) { + global half* knl_data = (global half*) ((global char*)p_knl + off_knl); + global float* src_data = (global float*) ((global char*)p_src + off_src); + global float* dst_data = (global float*) ((global char*)p_dst + off_dst); + + const uint K = Cout; + const uint CRS = Cin*KH*KW; + const uint NPQ = N*OH*OW; + + const uint lid_k = get_local_id(0); + const uint lid_npq = get_local_id(1); + const uint tid = lid_npq * WG_K + lid_k; + + const uint B_idx_K = get_group_id(0); + const uint B_idx_NPQ = get_group_id(1); + + const uint offset_k = B_idx_K * BS_K; + const uint offset_npq = B_idx_NPQ * BS_NPQ; + + local half* Ash = (local half*)shared; + local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS]; + + T_ACCUM regC[TS_K][TS_NPQ_VEC]; + for (int i = 0; i < TS_K; ++i) { + for (int j = 0; j < TS_NPQ_VEC; ++j) { + regC[i][j] = (T_ACCUM)(0.0f); + } + } + + const uint NB_CRS = splitWork(CRS, BS_CRS); + + for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) { + const uint 
offset_crs = B_idx_CRS * BS_CRS; + + for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) { + const uint k_l = i / BS_CRS; + const uint crs_l = i % BS_CRS; + const uint k_g = offset_k + k_l; + const uint crs_g = offset_crs + crs_l; + + if (k_g < K && crs_g < CRS) { + const uint Cin_idx = crs_g / (KW*KH); + const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW; + const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW; + const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03; + Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx]; + } else { + Ash[k_l * BS_CRS + crs_l] = (half)0.0f; + } + } + + for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) { + const uint crs_l = i / BS_NPQ_VEC; + const uint npq_l_vec = i % BS_NPQ_VEC; + const uint crs_g = offset_crs + crs_l; + + float4 val = (float4)(0.0f); + if (crs_g < CRS) { + const uint Cin_idx = crs_g / (KW * KH); + const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW; + const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW; + for (int v = 0; v < VEC_SIZE; ++v) { + const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v; + if (npq_g < NPQ) { + const uint N_idx = npq_g / (OH * OW); + const uint pq_idx = npq_g % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1); + const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0); + + if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) { + const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13; + ((float*)&val)[v] = src_data[src_idx]; + } + } + } + } + Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + #pragma unroll + for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) { + half regA[TS_K]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l]; + } + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + float4 regB = Bsh[crs_l * 
BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + const uint k_g = offset_k + lid_k * TS_K + k_l_reg; + if (k_g >= K) continue; + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE; + + const uint N_idx = npq_g_base / (OH * OW); + const uint pq_idx = npq_g_base % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + + if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) { + const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3; + vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]); + } else { + T_ACCUM res = regC[k_l_reg][npq_l_vec_reg]; + for (int v = 0; v < VEC_SIZE; ++v) { + const uint npq_g = npq_g_base + v; + if (npq_g < NPQ) { + const uint N_idx_s = npq_g / (OH*OW); + const uint pq_idx_s = npq_g % (OH*OW); + const uint OH_idx_s = pq_idx_s / OW; + const uint OW_idx_s = pq_idx_s % OW; + const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3; + dst_data[dst_idx_s] = ((float*)&res)[v]; + } + } + } + } + } +} From 98c65715ac7e05f19e737039a310fdbf34fccd0d Mon Sep 17 00:00:00 2001 From: rmatif Date: Thu, 17 Jul 2025 11:06:44 +0000 Subject: [PATCH 5/5] resolve conflicts --- ggml/src/ggml-opencl/ggml-opencl.cpp | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 46d38ff913b91..495d1d4014c20 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4997,12 +4997,7 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor 
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); } -<<<<<<< HEAD static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { -======= -static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_TENSOR_BINARY_OP_LOCALS; ->>>>>>> 4d5d5a83 (add conv2d kernel) ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; @@ -5013,7 +5008,6 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; -<<<<<<< HEAD const int M = src0->ne[1]; const int N = src1->ne[1]; const int K = src0->ne[0]; @@ -5053,7 +5047,20 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co }; backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); -======= +} + +static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_TENSOR_BINARY_OP_LOCALS; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13; const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1; @@ -5126,7 +5133,6 @@ static 
void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co GGML_UNUSED(dst); CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 2, NULL, global_work_size, local_work_size, 0, NULL, NULL)); #endif ->>>>>>> 4d5d5a83 (add conv2d kernel) } static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -6971,4 +6977,3 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor func(backend, tensor->src[0], tensor->src[1], tensor); return true; } -