@@ -3010,8 +3010,11 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3010
3010
SYCL_CHECK (CHECK_TRY_ERROR (
3011
3011
stream->memset (dev_cur_src1_row.get (), 0 , sizeof (int ))));
3012
3012
3013
+ const unsigned int max_work_group_size = ggml_sycl_info ().work_group_size (ctx.device );
3014
+ assert (max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0 );
3015
+
3013
3016
{
3014
- sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne10, 768u ));
3017
+ sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne10, max_work_group_size ));
3015
3018
sycl::range<3 > grid_dims (1 , n_ids, ids->ne [1 ]);
3016
3019
stream->submit ([&](sycl::handler &cgh) {
3017
3020
sycl::local_accessor<int , 0 > src1_row_acc (cgh);
@@ -3056,7 +3059,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3056
3059
ggml_sycl_mul_mat (ctx, &src0_row, &src1_row, &dst_row);
3057
3060
3058
3061
{
3059
- sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne0, 768u ));
3062
+ sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne0, max_work_group_size ));
3060
3063
sycl::range<3 > grid_dims (1 , 1 , num_src1_rows);
3061
3064
stream->submit ([&](sycl::handler &cgh) {
3062
3065
const char *__restrict dst_contiguous_get =
0 commit comments