Skip to content

Commit 6bdda13

Browse files
authored
opencl: add tiled mul_mat_f16_f32 (#14535)
* add tiled mul_mat_f16_f32
* fix trailing whitespace
* add insightful comments
1 parent 0b88557 commit 6bdda13

File tree

3 files changed

+213
-0
lines changed

3 files changed

+213
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ set(GGML_OPENCL_KERNELS
104104
tanh
105105
pad
106106
repeat
107+
mul_mat_f16_f32
107108
)
108109

109110
foreach (K ${GGML_OPENCL_KERNELS})

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ struct ggml_backend_opencl_context {
368368
cl_program program_mul_mv_f16_f32;
369369
cl_program program_mul_mv_f32_f32;
370370
cl_program program_mul;
371+
cl_program program_mul_mat_f16_f32_tiled;
371372
cl_program program_div;
372373
cl_program program_sub;
373374
cl_program program_norm;
@@ -422,6 +423,7 @@ struct ggml_backend_opencl_context {
422423
cl_kernel kernel_mul_mat_f16_f32_1row;
423424
cl_kernel kernel_mul_mat_f16_f32;
424425
cl_kernel kernel_mul_mat_f16_f32_l4;
426+
cl_kernel kernel_mul_mat_f16_f32_tiled;
425427
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
426428
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
427429
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
@@ -1015,6 +1017,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
10151017
GGML_LOG_CONT(".");
10161018
}
10171019

1020+
// mul_mat_f16_f32_tiled
1021+
{
1022+
#ifdef GGML_OPENCL_EMBED_KERNELS
1023+
const std::string kernel_src {
1024+
#include "mul_mat_f16_f32.cl.h"
1025+
};
1026+
#else
1027+
const std::string kernel_src = read_file("mul_mat_f16_f32.cl");
1028+
#endif
1029+
backend_ctx->program_mul_mat_f16_f32_tiled =
1030+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1031+
1032+
CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, "mul_mat_f16_f32", &err), err));
1033+
GGML_LOG_CONT(".");
1034+
}
1035+
10181036
// mul
10191037
{
10201038
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -4927,6 +4945,58 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
49274945
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
49284946
}
49294947

4948+
// Launch the tiled f16 x f32 matrix multiplication kernel.
//
// Computes dst = src0 * src1 where src0 is f16 with shape (K, M) in ggml
// order and src1 is f32 with shape (K, N); the kernel writes the result
// column-major into dst. Host-side tile constants below must mirror the
// #defines in mul_mat_f16_f32.cl.
static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    ggml_backend_opencl_context * ctx = (ggml_backend_opencl_context *) backend->context;

    ggml_tensor_extra_cl * extra_a = (ggml_tensor_extra_cl *) src0->extra;
    ggml_tensor_extra_cl * extra_b = (ggml_tensor_extra_cl *) src1->extra;
    ggml_tensor_extra_cl * extra_d = (ggml_tensor_extra_cl *) dst->extra;

    // Device buffer offsets must account for views into larger tensors.
    cl_ulong off_a = extra_a->offset + src0->view_offs;
    cl_ulong off_b = extra_b->offset + src1->view_offs;
    cl_ulong off_d = extra_d->offset + dst->view_offs;

    const int M = src0->ne[1]; // output rows
    const int N = src1->ne[1]; // output columns
    const int K = src0->ne[0]; // shared inner dimension

    cl_kernel kernel = ctx->kernel_mul_mat_f16_f32_tiled;

    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int),      &M));
    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int),      &N));
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),      &K));
    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &extra_a->data_device));
    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &off_a));
    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem),   &extra_b->data_device));
    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &off_b));
    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem),   &extra_d->data_device));
    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &off_d));

    // Tiling parameters. These need to be tuned for optimal performance,
    // and they MUST match the #defines in the kernel mul_mat_f16_f32.cl:
    //
    //   OPWM / OPWN: output tile per work-group (a group computes OPWM x OPWN).
    //   TPWM / TPWN: threads per work-group (the local size).
    //   OPTM / OPTN: output elements per thread (each computes OPTM x OPTN).
    //
    // Required invariants:
    //   OPWM = TPWM * OPTM
    //   OPWN = TPWN * OPTN
    const int OPWM = 64;
    const int OPWN = 64;
    const int TPWM = 16;
    const int TPWN = 8;

    size_t local_work_size[2] = { TPWM, TPWN };
    // One work-group per output tile; round up so partial edge tiles are
    // still launched (the kernel bounds-checks its stores).
    size_t global_work_size[2] = {
        (size_t) ((M + OPWM - 1) / OPWM) * TPWM,
        (size_t) ((N + OPWN - 1) / OPWN) * TPWN,
    };

    ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}
4999+
49305000
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
49315001
GGML_ASSERT(src0);
49325002
GGML_ASSERT(src0->extra);
@@ -4940,6 +5010,18 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
49405010

49415011
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
49425012

5013+
if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
5014+
src0->ne[1] > 32 && // M > 32
5015+
src1->ne[1] > 32 && // N > 32
5016+
src0->ne[0] > 32 && // K > 32
5017+
src0->ne[2] == 1 && src0->ne[3] == 1 &&
5018+
src1->ne[2] == 1 && src1->ne[3] == 1 &&
5019+
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
5020+
backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
5021+
ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
5022+
return;
5023+
}
5024+
49435025
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
49445026
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
49455027
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

// Tile geometry. These values MUST stay in sync with the host-side
// constants in ggml_cl_mul_mat_f16_f32_tiled (ggml-opencl.cpp).
#define OPWM 64            // output rows per work-group
#define OPWN 64            // output columns per work-group
#define CPWK 8             // K elements consumed per tile iteration
#define OPTM 4             // output rows per thread
#define OPTN 8             // output columns per thread

#define WG_M (OPWM / OPTM) // work-group size along M (16)
#define WG_N (OPWN / OPTN) // work-group size along N (8)
#define VEC_K (CPWK / 4)   // 4-wide vectors per K slice (2)

// Tiled matmul: C[N x M, column-major] = A[M x K, f16] * B[N x K, f32]^T.
// Each work-group stages an OPWM x CPWK slice of A and an OPWN x CPWK
// slice of B in local memory, then each thread accumulates an
// OPTM x OPTN block of the output in registers.
REQD_SUBGROUP_SIZE_128
__kernel void mul_mat_f16_f32(
    const int M, const int N, const int K,
    __global const void* A_void, ulong A_offset,
    __global const void* B_void, ulong B_offset,
    __global void* C_void, ulong C_offset) {

    __global const half  * A = (__global const half  *)((__global const char *) A_void + A_offset);
    __global const float * B = (__global const float *)((__global const char *) B_void + B_offset);
    __global       float * C = (__global       float *)((__global       char *) C_void + C_offset);

    const int tid_m   = get_local_id(0);
    const int tid_n   = get_local_id(1);
    const int flat_id = tid_n * WG_M + tid_m; // linear thread id in [0, 128)

    const int base_m = get_group_id(0) * OPWM; // first output row of this group's tile
    const int base_n = get_group_id(1) * OPWN; // first output column of this group's tile

    __local half4  tile_a[OPWM][VEC_K];
    __local float4 tile_b[OPWN][VEC_K];

    // Per-thread accumulator block, kept in registers.
    float acc[OPTM][OPTN];
    for (int i = 0; i < OPTM; i++) {
        for (int j = 0; j < OPTN; j++) {
            acc[i][j] = 0.0f;
        }
    }

    const int n_tiles = (K + CPWK - 1) / CPWK;

    // Cooperative-load mapping: the 128 threads exactly cover
    // OPWM rows x VEC_K vectors of A and OPWN rows x VEC_K vectors of B.
    const int a_row  = flat_id % OPWM;
    const int a_vec  = flat_id / OPWM;
    const int a_grow = base_m + a_row;  // global row of A this thread loads

    const int b_row  = flat_id % OPWN;
    const int b_vec  = flat_id / OPWN;
    const int b_grow = base_n + b_row;  // global row of B this thread loads

    for (int t = 0; t < n_tiles; t++) {
        const int k0 = t * CPWK;
        const int ka = k0 + a_vec * 4;
        const int kb = k0 + b_vec * 4;

        // Stage A slice, zero-padding past the M/K edges.
        if (a_grow < M && ka < K) {
            if (ka + 3 < K) {
                tile_a[a_row][a_vec] = vload4(0, A + a_grow * K + ka);
            } else {
                half4 va = (half4)(0.0h);
                if (ka     < K) va.s0 = A[a_grow * K + ka];
                if (ka + 1 < K) va.s1 = A[a_grow * K + ka + 1];
                if (ka + 2 < K) va.s2 = A[a_grow * K + ka + 2];
                tile_a[a_row][a_vec] = va;
            }
        } else {
            tile_a[a_row][a_vec] = (half4)(0.0h);
        }

        // Stage B slice, zero-padding past the N/K edges.
        if (b_grow < N && kb < K) {
            if (kb + 3 < K) {
                tile_b[b_row][b_vec] = vload4(0, B + b_grow * K + kb);
            } else {
                float4 vb = (float4)(0.0f);
                if (kb     < K) vb.s0 = B[b_grow * K + kb];
                if (kb + 1 < K) vb.s1 = B[b_grow * K + kb + 1];
                if (kb + 2 < K) vb.s2 = B[b_grow * K + kb + 2];
                tile_b[b_row][b_vec] = vb;
            }
        } else {
            tile_b[b_row][b_vec] = (float4)(0.0f);
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        #pragma unroll
        for (int kv = 0; kv < VEC_K; kv++) {
            // Fetch this thread's A fragments (promoted to f32 for the dot).
            float4 a_frag[OPTM];
            int row = tid_m;
            for (int i = 0; i < OPTM; i++) {
                a_frag[i] = convert_float4(tile_a[row][kv]);
                row += WG_M;
            }

            // Fetch this thread's B fragments.
            float4 b_frag[OPTN];
            int col = tid_n;
            for (int j = 0; j < OPTN; j++) {
                b_frag[j] = tile_b[col][kv];
                col += WG_N;
            }

            // Outer product accumulate over the 4-wide K vector.
            for (int i = 0; i < OPTM; i++) {
                for (int j = 0; j < OPTN; j++) {
                    acc[i][j] += dot(a_frag[i], b_frag[j]);
                }
            }
        }
        // Make sure every thread is done reading before the next load phase.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write back; C is column-major (M contiguous rows per column).
    for (int i = 0; i < OPTM; i++) {
        const int gm = base_m + tid_m + i * WG_M;
        if (gm < M) {
            for (int j = 0; j < OPTN; j++) {
                const int gn = base_n + tid_n + j * WG_N;
                if (gn < N) {
                    C[gn * M + gm] = acc[i][j];
                }
            }
        }
    }
}

0 commit comments

Comments
 (0)