Skip to content

OpenCL: add conv2d kernel #14403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ggml/src/ggml-opencl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ set(GGML_OPENCL_KERNELS
pad
repeat
mul_mat_f16_f32
conv2d
conv2d_f16_f32
)

foreach (K ${GGML_OPENCL_KERNELS})
Expand Down
143 changes: 143 additions & 0 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ struct ggml_backend_opencl_context {
cl_program program_tanh;
cl_program program_upscale;
cl_program program_concat;
cl_program program_conv_2d_f16;
cl_program program_conv_2d_f32;
cl_program program_conv_2d_f16_f32;
cl_program program_tsembd;
cl_program program_mul_mv_id_q4_0_f32_8x_flat;

Expand Down Expand Up @@ -441,6 +444,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_upscale_bilinear;
cl_kernel kernel_concat_f32_contiguous;
cl_kernel kernel_concat_f32_non_contiguous;
cl_kernel kernel_conv_2d_f16;
cl_kernel kernel_conv_2d_f32;
cl_kernel kernel_conv_2d_f16_f32;
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;

Expand Down Expand Up @@ -1478,6 +1484,47 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}

// conv2d
// Builds the conv2d programs and creates their kernels:
//   - conv2d.cl is compiled twice: once with " -DUSE_FP16=1" appended to
//     compile_opts (f16 variant) and once with compile_opts alone (f32 variant);
//   - conv2d_f16_f32.cl is compiled once with compile_opts.
// Every kernel is created under the entry-point name "kernel_conv_2d".
// When a source string is empty, a warning is logged and the corresponding
// program/kernel handles are set to nullptr instead.
{
#ifdef GGML_OPENCL_EMBED_KERNELS
// Embedded build: the kernel text comes from the included *.cl.h headers.
const std::string kernel_src {
#include "conv2d.cl.h"
};
const std::string kernel_src_f16_f32 {
#include "conv2d_f16_f32.cl.h"
};
#else
// Non-embedded build: the kernel text is obtained through read_file().
const std::string kernel_src = read_file("conv2d.cl");
const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
#endif
if (!kernel_src.empty()) {
// f16 variant: same source, built with USE_FP16 defined.
backend_ctx->program_conv_2d_f16 =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
GGML_LOG_CONT(".");
// f32 variant: same source, built without USE_FP16.
backend_ctx->program_conv_2d_f32 =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
GGML_LOG_CONT(".");
} else {
GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
// Leave both variants explicitly null so they are never used uninitialised.
backend_ctx->program_conv_2d_f16 = nullptr;
backend_ctx->kernel_conv_2d_f16 = nullptr;
backend_ctx->program_conv_2d_f32 = nullptr;
backend_ctx->kernel_conv_2d_f32 = nullptr;
}
if (!kernel_src_f16_f32.empty()) {
// Mixed variant, selected by the host for F16 src0 with F32 src1.
backend_ctx->program_conv_2d_f16_f32 =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
GGML_LOG_CONT(".");
} else {
GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
backend_ctx->program_conv_2d_f16_f32 = nullptr;
backend_ctx->kernel_conv_2d_f16_f32 = nullptr;
}
}

// mul_mv_id_q4_0_f32_8x_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
Expand Down Expand Up @@ -2361,6 +2408,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
op->src[0]->ne[3] == 1 && op->ne[3] == 1;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
(op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
case GGML_OP_CONCAT:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_TIMESTEP_EMBEDDING:
Expand Down Expand Up @@ -4998,6 +5049,92 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}

// Dispatches GGML_OP_CONV_2D to one of the OpenCL "kernel_conv_2d" kernels.
//
// Tensor roles, as read below:
//   src0 - filter: ne00 = KW, ne01 = KH, ne02 = Cin, ne03 = Cout
//   src1 - input:  ne10 = W,  ne11 = H,  ne13 = N
//   dst  - output: ne0  = OW, ne1  = OH
// dst->op_params[0..5] are read as s0, s1, p0, p1, d0, d1.
//
// Supported (src0, src1) type pairs are F16/F16, F32/F32 and F16/F32; any other
// combination trips an assertion. Byte strides are converted to element strides
// before being handed to the kernel.
static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    // Validate the operands before anything dereferences them.
    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
    GGML_ASSERT(src1);
    GGML_ASSERT(src1->extra);
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

    GGML_TENSOR_BINARY_OP_LOCALS;
    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

    cl_ulong offset0 = extra0->offset + src0->view_offs;
    cl_ulong offset1 = extra1->offset + src1->view_offs;
    cl_ulong offsetd = extrad->offset + dst->view_offs;

    // Problem dimensions.
    const cl_uint Cout = ne03;
    const cl_uint Cin  = ne02;
    const cl_uint N    = ne13;
    const cl_uint KW   = ne00;
    const cl_uint KH   = ne01;
    const cl_uint W    = ne10;
    const cl_uint H    = ne11;
    const cl_uint OW   = ne0;
    const cl_uint OH   = ne1;

    // Convolution parameters stored on the destination tensor.
    const cl_uint s0 = dst->op_params[0];
    const cl_uint s1 = dst->op_params[1];
    const cl_uint p0 = dst->op_params[2];
    const cl_uint p1 = dst->op_params[3];
    const cl_uint d0 = dst->op_params[4];
    const cl_uint d1 = dst->op_params[5];

    // The kernel indexes typed pointers, so strides are passed in elements, not bytes.
    const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type);
    const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type);
    const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
    const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type);
    const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type);
    const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
    const cl_uint cl_nb1  = nb1/ggml_type_size(dst->type);
    const cl_uint cl_nb2  = nb2/ggml_type_size(dst->type);
    const cl_uint cl_nb3  = nb3/ggml_type_size(dst->type);

    // Flattened batch/output-position extent. The kernel handles it as a 32-bit
    // value, so refuse sizes that would be silently truncated.
    const int64_t NPQ = (int64_t)N * OW * OH;
    GGML_ASSERT(NPQ <= (int64_t)UINT32_MAX && "conv2d: N*OW*OH does not fit the kernel's 32-bit index range");

    // Tile geometry; the same values are #defined in conv2d.cl.
    const uint32_t BS_K     = 64;
    const uint32_t BS_NPQ   = 64;
    const uint32_t BS_CRS   = 16;
    const uint32_t VEC_SIZE = 4;

    const uint32_t TS_K   = 4;
    const uint32_t TS_NPQ = 8;

    const uint32_t WG_K   = BS_K / TS_K;
    const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;

    // Ceiling division, done in 64 bits so neither the argument conversion nor the
    // rounding term can wrap.
    auto splitWork = [](uint64_t work_size, uint64_t block_size) {
        return (uint32_t)((work_size + block_size - 1) / block_size);
    };
    const uint32_t NB_K   = splitWork(Cout, BS_K);
    const uint32_t NB_NPQ = splitWork((uint64_t)NPQ, BS_NPQ);

    cl_kernel kernel;
    size_t shmem_size;

    // Local memory holds a BS_K x BS_CRS filter tile (scalar elements of the src0 type)
    // followed by a BS_CRS x BS_NPQ input tile stored as 4-wide vectors of the src1 type.
    if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
        kernel = backend_ctx->kernel_conv_2d_f16;
        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        kernel = backend_ctx->kernel_conv_2d_f32;
        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
        kernel = backend_ctx->kernel_conv_2d_f16_f32;
        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
    } else {
        GGML_ASSERT(false && "Unsupported data type combination for conv2d");
        return;
    }

    // The loader leaves the handle null when the kernel source was missing.
    GGML_ASSERT(kernel != nullptr && "conv2d kernel is not available");

    // Argument order must follow the kernel_conv_2d signature exactly.
    cl_uint idx = 0;
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem),   &extra0->data_device));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem),   &extra1->data_device));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem),   &extrad->data_device));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));
    CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL)); // local scratch buffer

    const cl_uint scalar_args[] = {
        Cout, Cin, N,
        KW, KH, W, H, OW, OH,
        s0, s1, p0, p1, d0, d1,
        cl_nb01, cl_nb02, cl_nb03,
        cl_nb11, cl_nb12, cl_nb13,
        cl_nb1,  cl_nb2,  cl_nb3,
    };
    for (const cl_uint & arg : scalar_args) {
        CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &arg));
    }

    // One work-group per (BS_K x BS_NPQ) output tile, WG_K x WG_NPQ work-items each.
    size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
    size_t local_work_size[]  = { (size_t)WG_K, (size_t)WG_NPQ, 1 };

    // Shared enqueue helper, as used by the other ops in this file.
    backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}

static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
Expand Down Expand Up @@ -6752,6 +6889,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
ggml_cl_upscale(backend, tensor->src[0], tensor);
return true;
case GGML_OP_CONV_2D:
if (!any_on_device) {
return false;
}
func = ggml_cl_conv_2d;
break;
case GGML_OP_CONCAT:
if (!any_on_device) {
return false;
Expand Down
185 changes: 185 additions & 0 deletions ggml/src/ggml-opencl/kernels/conv2d.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Element type selection. With USE_FP16 defined the tensors are read and written
// as half; otherwise as float. VSTORE_T_FLOAT4 stores a float4 value into a
// T_FLOAT destination (for half via vstore_half4_rte, for float via vstore4).
#ifdef USE_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define T_FLOAT half
#define T_FLOAT4 half4
#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p)
#else
#define T_FLOAT float
#define T_FLOAT4 float4
#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p)
#endif

// Optional Qualcomm subgroup-size attribute; expands to nothing when the
// extension macro is not defined.
// NOTE(review): the macro name says 128 while the attribute requests "full";
// presumably "full" means 128-wide subgroups on the targeted GPUs - confirm.
#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

// Accumulation is always done in float4, independent of T_FLOAT.
#define T_ACCUM float4
#define VEC_SIZE 4

// Tile extents handled by one work-group along K, NPQ and CRS.
#define BS_K 64
#define BS_NPQ 64
#define BS_CRS 16

// Sub-tile extents handled by one work-item along K and NPQ.
#define TS_K 4
#define TS_NPQ 8

// Work-group dimensions that follow from the two sets of extents above (16 x 8).
#define WG_K (BS_K / TS_K)
#define WG_NPQ (BS_NPQ / TS_NPQ)

// NPQ extents expressed in VEC_SIZE-wide vectors.
#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)

// Returns how many blocks of `block_size` items are needed to cover `work_size`
// items (ceiling division). Computed as quotient plus a remainder test so the
// result stays correct for work_size values near the top of the uint range,
// where `work_size + block_size - 1` would wrap around.
// block_size must be non-zero.
static inline uint splitWork(uint work_size, uint block_size){
    return work_size / block_size + ((work_size % block_size) != 0u ? 1u : 0u);
}

// 2D convolution evaluated as a tiled matrix product C[K][NPQ] = A[K][CRS] * B[CRS][NPQ], with
//   K   = Cout            (output channel),
//   CRS = Cin * KH * KW   (flattened input channel / filter row / filter column),
//   NPQ = N * OH * OW     (flattened batch / output row / output column).
// A is read from the filter tensor; B is gathered on the fly from the input tensor.
// Each work-group produces one BS_K x BS_NPQ output tile and each work-item
// accumulates a TS_K x TS_NPQ sub-tile in float4 registers.
//
// Arguments:
//   p_knl, off_knl   filter buffer and byte offset into it
//   p_src, off_src   input buffer and byte offset into it
//   p_dst, off_dst   output buffer and byte offset into it
//   shared           local scratch: BS_K*BS_CRS T_FLOAT values (A tile) followed by
//                    BS_CRS*BS_NPQ_VEC T_FLOAT4 values (B tile)
//   Cout, Cin, N     output channels, input channels, batch size
//   KW, KH           filter width and height
//   W, H             input width and height
//   OW, OH           output width and height
//   s0, s1           multiplier applied to the output column / row index (stride)
//   p0, p1           amount subtracted from the input column / row (padding)
//   d0, d1           multiplier applied to the filter column / row index (dilation)
//   nb01, nb02, nb03 filter strides, in elements, for KH, Cin and Cout
//   nb11, nb12, nb13 input strides, in elements, for H, Cin and N
//   nb1, nb2, nb3    output strides, in elements, for OH, Cout and N
REQD_SUBGROUP_SIZE_128
kernel void kernel_conv_2d(
global void* p_knl,
ulong off_knl,
global void* p_src,
ulong off_src,
global void* p_dst,
ulong off_dst,
local void* shared,
uint Cout, uint Cin, uint N,
uint KW, uint KH, uint W, uint H, uint OW, uint OH,
uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
uint nb01, uint nb02, uint nb03,
uint nb11, uint nb12, uint nb13,
uint nb1, uint nb2, uint nb3
) {
// Apply the byte offsets first, then view the buffers as arrays of T_FLOAT.
global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl);
global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src);
global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst);

const uint K = Cout;
const uint CRS = Cin*KH*KW;
const uint NPQ = N*OH*OW;

// Position of this work-item inside the work-group, plus a linear id used to
// spread the cooperative tile loads across all WG_K * WG_NPQ work-items.
const uint lid_k = get_local_id(0);
const uint lid_npq = get_local_id(1);
const uint tid = lid_npq * WG_K + lid_k;

const uint B_idx_K = get_group_id(0);
const uint B_idx_NPQ = get_group_id(1);

// Origin of this work-group's output tile.
const uint offset_k = B_idx_K * BS_K;
const uint offset_npq = B_idx_NPQ * BS_NPQ;

// Carve the local buffer into the A tile followed by the B tile.
local T_FLOAT* Ash = (local T_FLOAT*)shared;
local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS];

// Per-work-item accumulators, zero-initialised.
T_ACCUM regC[TS_K][TS_NPQ_VEC];
for (int i = 0; i < TS_K; ++i) {
for (int j = 0; j < TS_NPQ_VEC; ++j) {
regC[i][j] = (T_ACCUM)(0.0f);
}
}

const uint NB_CRS = splitWork(CRS, BS_CRS);

// Walk the reduction dimension in chunks of BS_CRS.
for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
const uint offset_crs = B_idx_CRS * BS_CRS;

// Cooperative load of the A tile (filter weights). Entries outside K or CRS
// are stored as zero so they add nothing to the accumulators.
for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
const uint k_l = i / BS_CRS;
const uint crs_l = i % BS_CRS;
const uint k_g = offset_k + k_l;
const uint crs_g = offset_crs + crs_l;

if (k_g < K && crs_g < CRS) {
// Decompose the flattened CRS index into (Cin, KH, KW).
const uint Cin_idx = crs_g / (KW*KH);
const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
} else {
Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f;
}
}

// Cooperative gather of the B tile (input samples), one VEC_SIZE-wide vector
// per iteration. Lanes that fall outside CRS, NPQ or the input image keep
// their initial zero value.
for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
const uint crs_l = i / BS_NPQ_VEC;
const uint npq_l_vec = i % BS_NPQ_VEC;
const uint crs_g = offset_crs + crs_l;

T_FLOAT4 val = (T_FLOAT4)(0.0f);
if (crs_g < CRS) {
const uint Cin_idx = crs_g / (KW * KH);
const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
for (int v = 0; v < VEC_SIZE; ++v) {
const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
if (npq_g < NPQ) {
// Decompose the flattened NPQ index into (N, OH, OW).
const uint N_idx = npq_g / (OH * OW);
const uint pq_idx = npq_g % (OH * OW);
const uint OH_idx = pq_idx / OW;
const uint OW_idx = pq_idx % OW;
// Input coordinate = out * stride + filter * dilation - padding. The
// expression is evaluated in unsigned arithmetic and then cast to int;
// a coordinate before the start of the image is expected to come out
// negative here and is rejected by the bounds check below.
const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);

if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
((T_FLOAT*)&val)[v] = src_data[src_idx];
}
}
}
}
Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
}

// Both tiles must be completely written before any work-item reads them.
barrier(CLK_LOCAL_MEM_FENCE);

// Multiply-accumulate over this CRS chunk: for every reduction index, take
// TS_K filter values and TS_NPQ_VEC input vectors from local memory and
// update the TS_K x TS_NPQ_VEC accumulators in float precision.
#pragma unroll
for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
T_FLOAT regA[TS_K];
for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
}

for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]);
}
}
}
// Everyone must be done reading the tiles before the next iteration overwrites them.
barrier(CLK_LOCAL_MEM_FENCE);
}

// Write the accumulated sub-tile to the output tensor.
for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
if (k_g >= K) continue;

for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;

const uint N_idx = npq_g_base / (OH * OW);
const uint pq_idx = npq_g_base % (OH * OW);
const uint OH_idx = pq_idx / OW;
const uint OW_idx = pq_idx % OW;

// Fast path: all VEC_SIZE results lie in the same output row, inside NPQ,
// and the row stride equals OW, so one vector store covers them.
if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
} else {
// Slow path: the vector straddles a row/batch boundary or the end of NPQ,
// so each lane is decoded and stored individually.
T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
for (int v = 0; v < VEC_SIZE; ++v) {
const uint npq_g = npq_g_base + v;
if (npq_g < NPQ) {
const uint N_idx_s = npq_g / (OH*OW);
const uint pq_idx_s = npq_g % (OH*OW);
const uint OH_idx_s = pq_idx_s / OW;
const uint OW_idx_s = pq_idx_s % OW;
const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]);
}
}
}
}
}
}
Loading
Loading