@@ -5763,19 +5763,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
5763
5763
5764
5764
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
5765
5765
5766
- const int ne00 = src0 ? src0->ne [0 ] : 0 ;
5767
- const int ne01 = src0 ? src0->ne [1 ] : 0 ;
5768
- const int ne02 = src0 ? src0->ne [2 ] : 0 ;
5769
- const int ne03 = src0 ? src0->ne [3 ] : 0 ;
5766
+ const int ne00 = src0->ne [0 ];
5767
+ const int ne01 = src0->ne [1 ];
5768
+ const int ne02 = src0->ne [2 ];
5769
+ const int ne03 = src0->ne [3 ];
5770
+
5771
+ const cl_long nb01 = src0->nb [1 ];
5772
+ const cl_long nb02 = src0->nb [2 ];
5773
+ const cl_long nb03 = src0->nb [3 ];
5774
+
5775
+ const int ne12 = src1 ? src1->ne [2 ] : 0 ;
5776
+ const int ne13 = src1 ? src1->ne [3 ] : 0 ;
5777
+
5778
+ const cl_long nb11 = src1 ? src1->nb [1 ] : 0 ;
5779
+ const cl_long nb12 = src1 ? src1->nb [2 ] : 0 ;
5780
+ const cl_long nb13 = src1 ? src1->nb [3 ] : 0 ;
5781
+
5782
+ const cl_long nb1 = dst->nb [1 ];
5783
+ const cl_long nb2 = dst->nb [2 ];
5784
+ const cl_long nb3 = dst->nb [3 ];
5770
5785
5771
5786
float scale, max_bias;
5772
5787
memcpy (&scale, dst->op_params + 0 , sizeof (float ));
5773
5788
memcpy (&max_bias, dst->op_params + 1 , sizeof (float ));
5774
5789
5775
- const int nrows_x = ggml_nrows (src0);
5776
- const int nrows_y = src0->ne [1 ];
5777
-
5778
- const int n_head = nrows_x/nrows_y;
5790
+ const int n_head = src0->ne [2 ];
5779
5791
const int n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
5780
5792
5781
5793
const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
@@ -5820,13 +5832,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
5820
5832
CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
5821
5833
CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
5822
5834
CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
5823
- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
5824
- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
5825
- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (float ), &scale));
5826
- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (float ), &max_bias));
5827
- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (float ), &m0));
5828
- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float ), &m1));
5829
- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &n_head_log2));
5835
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb01));
5836
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb02));
5837
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb03));
5838
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne12));
5839
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne13));
5840
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb11));
5841
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb12));
5842
+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb13));
5843
+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb1));
5844
+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb2));
5845
+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (cl_ulong), &nb3));
5846
+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (float ), &scale));
5847
+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (float ), &max_bias));
5848
+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (float ), &m0));
5849
+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (float ), &m1));
5850
+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &n_head_log2));
5830
5851
5831
5852
size_t global_work_size[] = {(size_t )ne01*nth, (size_t )ne02, (size_t )ne03};
5832
5853
size_t local_work_size[] = {(size_t )nth, 1 , 1 };
0 commit comments