@@ -3530,8 +3530,11 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3530
3530
SYCL_CHECK (CHECK_TRY_ERROR (
3531
3531
stream->memset (dev_cur_src1_row.get (), 0 , sizeof (int ))));
3532
3532
3533
+ const unsigned int max_work_group_size = ggml_sycl_info ().max_work_group_sizes [ctx.device ];
3534
+ assert (work_group_size % (WARP_SIZE * WARP_SIZE) == 0 );
3535
+
3533
3536
{
3534
- sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne10, 768u ));
3537
+ sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne10, max_work_group_size ));
3535
3538
sycl::range<3 > grid_dims (1 , n_ids, ids->ne [1 ]);
3536
3539
sycl_launch (stream, [&](sycl::handler & cgh) {
3537
3540
sycl::local_accessor<int , 0 > src1_row_acc (cgh);
@@ -3575,7 +3578,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3575
3578
ggml_sycl_mul_mat (ctx, &src0_row, &src1_row, &dst_row);
3576
3579
3577
3580
{
3578
- sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne0, 768u ));
3581
+ sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne0, max_work_group_size ));
3579
3582
sycl::range<3 > grid_dims (1 , 1 , num_src1_rows);
3580
3583
sycl_launch (stream, [&](sycl::handler & cgh) {
3581
3584
const char *__restrict dst_contiguous_get =
0 commit comments