Skip to content

Commit 21a9e3c

Browse files
lhezqnixsynapse
authored and committed
opencl : broadcast for soft_max (ggml-org#14510)
1 parent 85f709a commit 21a9e3c

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 36 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -5763,19 +5763,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
57635763

57645764
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
57655765

5766-
const int ne00 = src0 ? src0->ne[0] : 0;
5767-
const int ne01 = src0 ? src0->ne[1] : 0;
5768-
const int ne02 = src0 ? src0->ne[2] : 0;
5769-
const int ne03 = src0 ? src0->ne[3] : 0;
5766+
const int ne00 = src0->ne[0];
5767+
const int ne01 = src0->ne[1];
5768+
const int ne02 = src0->ne[2];
5769+
const int ne03 = src0->ne[3];
5770+
5771+
const cl_long nb01 = src0->nb[1];
5772+
const cl_long nb02 = src0->nb[2];
5773+
const cl_long nb03 = src0->nb[3];
5774+
5775+
const int ne12 = src1 ? src1->ne[2] : 0;
5776+
const int ne13 = src1 ? src1->ne[3] : 0;
5777+
5778+
const cl_long nb11 = src1 ? src1->nb[1] : 0;
5779+
const cl_long nb12 = src1 ? src1->nb[2] : 0;
5780+
const cl_long nb13 = src1 ? src1->nb[3] : 0;
5781+
5782+
const cl_long nb1 = dst->nb[1];
5783+
const cl_long nb2 = dst->nb[2];
5784+
const cl_long nb3 = dst->nb[3];
57705785

57715786
float scale, max_bias;
57725787
memcpy(&scale, dst->op_params + 0, sizeof(float));
57735788
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
57745789

5775-
const int nrows_x = ggml_nrows(src0);
5776-
const int nrows_y = src0->ne[1];
5777-
5778-
const int n_head = nrows_x/nrows_y;
5790+
const int n_head = src0->ne[2];
57795791
const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
57805792

57815793
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -5820,13 +5832,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
58205832
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
58215833
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
58225834
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
5823-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
5824-
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
5825-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale));
5826-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias));
5827-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0));
5828-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1));
5829-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2));
5835+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
5836+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
5837+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
5838+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
5839+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
5840+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
5841+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
5842+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
5843+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
5844+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
5845+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
5846+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
5847+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
5848+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
5849+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
5850+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
58305851

58315852
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
58325853
size_t local_work_size[] = {(size_t)nth, 1, 1};

0 commit comments

Comments
 (0)