Skip to content

Commit 4d5d5a8

Browse files
committed
add conv2d kernel
1 parent 8846aac commit 4d5d5a8

File tree

3 files changed

+292
-0
lines changed

3 files changed

+292
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ set(GGML_OPENCL_KERNELS
102102
tanh
103103
pad
104104
repeat
105+
conv2d
105106
)
106107

107108
foreach (K ${GGML_OPENCL_KERNELS})

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ struct ggml_backend_opencl_context {
387387
cl_program program_tanh;
388388
cl_program program_upscale;
389389
cl_program program_concat;
390+
cl_program program_conv_2d;
390391
cl_program program_tsembd;
391392
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
392393

@@ -433,6 +434,7 @@ struct ggml_backend_opencl_context {
433434
cl_kernel kernel_upscale_bilinear;
434435
cl_kernel kernel_concat_f32_contiguous;
435436
cl_kernel kernel_concat_f32_non_contiguous;
437+
cl_kernel kernel_conv_2d;
436438
cl_kernel kernel_timestep_embedding;
437439
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
438440

@@ -1400,6 +1402,27 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
14001402
}
14011403
}
14021404

1405+
// conv2d
1406+
{
1407+
#ifdef GGML_OPENCL_EMBED_KERNELS
1408+
const std::string kernel_src {
1409+
#include "conv2d.cl.h"
1410+
};
1411+
#else
1412+
const std::string kernel_src = read_file("conv2d.cl");
1413+
#endif
1414+
if (!kernel_src.empty()) {
1415+
backend_ctx->program_conv_2d =
1416+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1417+
CL_CHECK((backend_ctx->kernel_conv_2d = clCreateKernel(backend_ctx->program_conv_2d, "kernel_conv_2d", &err), err));
1418+
GGML_LOG_CONT(".");
1419+
} else {
1420+
GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
1421+
backend_ctx->program_conv_2d = nullptr;
1422+
backend_ctx->kernel_conv_2d = nullptr;
1423+
}
1424+
}
1425+
14031426
// mul_mv_id_q4_0_f32_8x_flat
14041427
{
14051428
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2255,6 +2278,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
22552278
op->src[0]->ne[3] == 1 && op->ne[3] == 1;
22562279
case GGML_OP_UPSCALE:
22572280
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2281+
case GGML_OP_CONV_2D:
2282+
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
22582283
case GGML_OP_CONCAT:
22592284
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
22602285
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -4685,6 +4710,73 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
46854710
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
46864711
}
46874712

4713+
// Launch the tiled conv2d OpenCL kernel (kernel_conv_2d in conv2d.cl).
//
//   src0: convolution weights, F32, ggml dims [KW, KH, Cin, Cout]  (ne00..ne03)
//   src1: input image,         F32, ggml dims [W,  H,  Cin, N   ]  (ne10..ne13)
//   dst : output,              F32, ggml dims [OW, OH, Cout, N  ]  (ne0..ne1, op_params hold s/p/d)
//
// The device kernel treats the convolution as an implicit GEMM over
// K = Cout rows and NPQ = N*OH*OW columns; the block sizes below
// (WG_SIZE, BS_K, BS_CRS, BS_NPQ, VEC_SIZE) must stay in sync with the
// corresponding #defines in conv2d.cl.
static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_TENSOR_BINARY_OP_LOCALS;
    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

    cl_ulong offset0 = extra0->offset + src0->view_offs;
    cl_ulong offset1 = extra1->offset + src1->view_offs;
    cl_ulong offsetd = extrad->offset + dst->view_offs;

    const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13;
    const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1;

    // Convolution parameters: stride, padding, dilation (x then y).
    const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1];
    const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
    const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];

    // The kernel indexes in elements, not bytes: convert byte strides to element strides.
    const cl_uint cl_nb01 = nb01/nb00; const cl_uint cl_nb02 = nb02/nb00; const cl_uint cl_nb03 = nb03/nb00;
    const cl_uint cl_nb11 = nb11/nb10; const cl_uint cl_nb12 = nb12/nb10; const cl_uint cl_nb13 = nb13/nb10;
    const cl_uint cl_nb1 = nb1/nb0; const cl_uint cl_nb2 = nb2/nb0; const cl_uint cl_nb3 = nb3/nb0;

    const int64_t NPQ = (int64_t)N * OW * OH;

    // Must match the #defines in conv2d.cl.
    const uint32_t WG_SIZE  = 128;
    const uint32_t BS_K     = 128;
    const uint32_t BS_CRS   = 16;
    const uint32_t BS_NPQ   = 64;
    const uint32_t VEC_SIZE = 4;

    // Ceiling division; mirrors splitWork() in conv2d.cl.
    auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (work_size + block_size - 1) / block_size; };
    const uint32_t NB_K   = splitWork(Cout, BS_K);
    // NPQ is computed in 64 bits to avoid intermediate overflow, but the kernel
    // addresses with 32-bit uints; narrow explicitly.
    const uint32_t NB_NPQ = splitWork((uint32_t)NPQ, BS_NPQ);

    // Local memory: tile A (BS_K rows x (BS_CRS+1) halves, padded stride) plus
    // tile B (BS_CRS rows x (BS_NPQ/VEC_SIZE+1) half4, padded stride).
    const size_t shmem_size = (size_t)(BS_K * (BS_CRS + 1) * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE + 1) * sizeof(cl_half4));

    cl_kernel kernel = backend_ctx->kernel_conv_2d;
    cl_uint idx = 0;
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem),   &extra0->data_device));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem),   &extra1->data_device));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem),   &extrad->data_device));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));
    CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL));  // dynamically-sized local buffer
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2));
    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));

    // dim 0: NB_K workgroups of WG_SIZE threads over output channels;
    // dim 1: one workgroup per BS_NPQ-wide slice of output pixels.
    size_t global_work_size[] = { (size_t)NB_K * WG_SIZE, (size_t)NB_NPQ, 1 };
    size_t local_work_size[]  = { (size_t)WG_SIZE, 1, 1 };

#ifdef GGML_OPENCL_PROFILING
    cl_event evt;
    CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));

    backend_ctx->profiling_info.emplace_back();
    populateProfilingInfo(backend_ctx->profiling_info.back(), evt, kernel, 3, global_work_size, local_work_size, dst);
#else
    CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
#endif
}
4779+
46884780
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
46894781
GGML_ASSERT(src0);
46904782
GGML_ASSERT(src0->extra);
@@ -6286,6 +6378,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
62866378
}
62876379
ggml_cl_upscale(backend, tensor->src[0], tensor);
62886380
return true;
6381+
case GGML_OP_CONV_2D:
6382+
if (!any_on_device) {
6383+
return false;
6384+
}
6385+
func = ggml_cl_conv_2d;
6386+
break;
62896387
case GGML_OP_CONCAT:
62906388
if (!any_on_device) {
62916389
return false;
ggml/src/ggml-opencl/kernels/conv2d.cl

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
#ifdef cl_intel_required_subgroup_size
4+
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
5+
#define INTEL_GPU 1
6+
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
7+
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
8+
#elif defined(cl_qcom_reqd_sub_group_size)
9+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
10+
#define ADRENO_GPU 1
11+
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
12+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
13+
#else
14+
#define REQD_SUBGROUP_SIZE_64
15+
#endif
16+
17+
#define T_FLOAT half
18+
#define T_FLOAT4 half4
19+
#define T_ACCUM float4
20+
#define VEC_SIZE 4
21+
22+
#define BS_K 128
23+
#define BS_CRS 16
24+
#define BS_NPQ 64
25+
#define TS_K 8
26+
#define TS_NPQ 8
27+
#define WG_SIZE 128
28+
29+
#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
30+
#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
31+
32+
#define NT_K (BS_K / TS_K)
33+
#define NT_NPQ (BS_NPQ / TS_NPQ)
34+
35+
// Ceiling division: number of blocks of size `block_size` needed to cover `work_size`.
static inline uint splitWork(uint work_size, uint block_size) {
    const uint rounded_up = work_size + block_size - 1;
    return rounded_up / block_size;
}
38+
39+
// 2D convolution expressed as an implicit GEMM:
//   A = flattened kernel weights (K   x CRS), K   = Cout, CRS = Cin*KH*KW
//   B = im2col view of the input (CRS x NPQ), NPQ = N*OH*OW
// Each workgroup produces a BS_K x BS_NPQ tile of the output. Slices of A and
// B, BS_CRS deep, are staged in local memory as half precision (the +1 stride
// padding is the usual trick to reduce local-memory bank conflicts), and each
// work-item accumulates a TS_K x TS_NPQ sub-tile in fp32 registers.
REQD_SUBGROUP_SIZE_64
kernel void kernel_conv_2d(
    global void* p_knl,
    ulong off_knl,
    global void* p_src,
    ulong off_src,
    global void* p_dst,
    ulong off_dst,
    local void* shared,
    uint Cout, uint Cin, uint N,
    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
    uint nb01, uint nb02, uint nb03,
    uint nb11, uint nb12, uint nb13,
    uint nb1, uint nb2, uint nb3
) {
    global float* knl_data = (global float*)((global char*)p_knl + off_knl);
    global float* src_data = (global float*)((global char*)p_src + off_src);
    global float* dst_data = (global float*)((global char*)p_dst + off_dst);

    const uint tid = get_local_id(0);

    const uint K   = Cout;
    const uint CRS = Cin*KH*KW;   // reduction dimension of the implicit GEMM
    const uint NPQ = N*OH*OW;     // total number of output pixels

    const uint NB_CRS = splitWork(CRS, BS_CRS);

    // Padded inner strides for the local-memory tiles.
    const uint Ash_stride     = BS_CRS + 1;
    const uint Bsh_stride_vec = BS_NPQ_VEC + 1;

    local T_FLOAT*  Ash = (local T_FLOAT*) shared;
    local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * Ash_stride];

    // Per-work-item fp32 accumulators, zero-initialized.
    T_ACCUM regC[TS_K][TS_NPQ_VEC];
    for (int r = 0; r < TS_K; ++r) {
        for (int c = 0; c < TS_NPQ_VEC; ++c) {
            regC[r][c] = (T_ACCUM)(0.0f);
        }
    }

    const uint B_idx_K   = get_group_id(0);   // which BS_K tile of output channels
    const uint B_idx_NPQ = get_group_id(1);   // which BS_NPQ tile of output pixels

    // This work-item's position inside the NT_K x NT_NPQ thread grid.
    const uint T_y = tid / NT_NPQ;
    const uint T_x = tid % NT_NPQ;

    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
        // Stage the BS_K x BS_CRS slice of the kernel weights into Ash.
        for (uint t = tid; t < BS_K * BS_CRS; t += WG_SIZE) {
            const uint k_l   = t / BS_CRS;
            const uint crs_l = t % BS_CRS;
            const uint k_g   = B_idx_K*BS_K + k_l;
            const uint crs_g = B_idx_CRS*BS_CRS + crs_l;
            if (k_g < K && crs_g < CRS) {
                // Unflatten crs_g into (Cin, KH, KW) coordinates.
                const uint Cin_idx = crs_g / (KW*KH);
                const uint KH_idx  = (crs_g - Cin_idx*KW*KH) / KW;
                const uint KW_idx  = crs_g - Cin_idx*KW*KH - KH_idx*KW;
                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
                Ash[k_l * Ash_stride + crs_l] = (T_FLOAT)knl_data[knl_idx];
            } else {
                Ash[k_l * Ash_stride + crs_l] = (T_FLOAT)0.0h;   // tail padding
            }
        }

        // Stage the BS_CRS x BS_NPQ im2col slice of the source into Bsh,
        // four output pixels at a time.
        for (uint t = tid; t < BS_CRS * BS_NPQ_VEC; t += WG_SIZE) {
            const uint crs_l     = t / BS_NPQ_VEC;
            const uint npq_l_vec = t % BS_NPQ_VEC;

            float4 val_f = (float4)(0.0f);
            const uint crs_g = B_idx_CRS * BS_CRS + crs_l;

            if (crs_g < CRS) {
                const uint Cin_idx = crs_g / (KW * KH);
                const uint KH_idx  = (crs_g - Cin_idx * KW * KH) / KW;
                const uint KW_idx  = crs_g - Cin_idx * KW * KH - KH_idx * KW;

                for (int v = 0; v < VEC_SIZE; ++v) {
                    const uint npq_g = B_idx_NPQ * BS_NPQ + npq_l_vec * VEC_SIZE + v;
                    if (npq_g < NPQ) {
                        // Unflatten npq_g into (N, OH, OW), then map back to the
                        // input position via stride/dilation/padding (may be negative).
                        const uint N_idx  = npq_g / (OH * OW);
                        const uint pq_idx = npq_g % (OH * OW);
                        const uint OH_idx = pq_idx / OW;
                        const uint OW_idx = pq_idx % OW;
                        const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);

                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
                            switch (v) {
                                case 0: val_f.s0 = src_data[src_idx]; break;
                                case 1: val_f.s1 = src_data[src_idx]; break;
                                case 2: val_f.s2 = src_data[src_idx]; break;
                                case 3: val_f.s3 = src_data[src_idx]; break;
                            }
                        }
                    }
                }
            }
            Bsh[crs_l * Bsh_stride_vec + npq_l_vec] = convert_half4(val_f);
        }
        barrier(CLK_LOCAL_MEM_FENCE);

        // Multiply the staged tiles, accumulating into the register sub-tile.
        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
            T_FLOAT regA[TS_K];
            for (uint k_r = 0; k_r < TS_K; ++k_r) {
                regA[k_r] = Ash[(T_y*TS_K + k_r)*Ash_stride + crs_l];
            }
            for (uint q_r = 0; q_r < TS_NPQ_VEC; ++q_r) {
                const T_FLOAT4 regB = Bsh[crs_l*Bsh_stride_vec + T_x*TS_NPQ_VEC + q_r];
                for (uint k_r = 0; k_r < TS_K; ++k_r) {
                    regC[k_r][q_r] = mad((float)regA[k_r], convert_float4(regB), regC[k_r][q_r]);
                }
            }
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write the accumulated sub-tile back to global memory.
    for (uint k_r = 0; k_r < TS_K; ++k_r) {
        const uint k_g = B_idx_K * BS_K + T_y * TS_K + k_r;
        if (k_g >= K) {
            continue;
        }

        for (uint q_r = 0; q_r < TS_NPQ_VEC; ++q_r) {
            const uint npq_g_base = B_idx_NPQ * BS_NPQ + (T_x * TS_NPQ_VEC + q_r) * VEC_SIZE;
            const uint N_idx  = npq_g_base / (OH*OW);
            const uint pq_idx = npq_g_base % (OH*OW);
            const uint OH_idx = pq_idx / OW;
            const uint OW_idx = pq_idx % OW;

            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
                // Fast path: destination row is contiguous and the whole vector
                // is in bounds -> one vstore4.
                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
                vstore4(regC[k_r][q_r], 0, &dst_data[dst_idx]);
            } else {
                // Slow path: scatter the four lanes individually with bounds checks.
                const T_ACCUM res = regC[k_r][q_r];
                for (int v = 0; v < VEC_SIZE; ++v) {
                    const uint npq_g = npq_g_base + v;
                    if (npq_g < NPQ) {
                        const uint N_idx_s   = npq_g / (OH*OW);
                        const uint pq_idx_s  = npq_g % (OH*OW);
                        const uint OH_idx_s  = pq_idx_s / OW;
                        const uint OW_idx_s  = pq_idx_s % OW;
                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
                        float val_f;
                        switch (v) {
                            case 0: val_f = res.s0; break;
                            case 1: val_f = res.s1; break;
                            case 2: val_f = res.s2; break;
                            default:val_f = res.s3; break;
                        }
                        dst_data[dst_idx_s] = val_f;
                    }
                }
            }
        }
    }
}

0 commit comments

Comments
 (0)