Skip to content

Commit 349ea79

Browse files
use max work group size for device to replace the magic number (#14732)
1 parent 670e136 commit 349ea79

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3530,8 +3530,11 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
35303530
SYCL_CHECK(CHECK_TRY_ERROR(
35313531
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
35323532

3533+
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
3534+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
3535+
35333536
{
3534-
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
3537+
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
35353538
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
35363539
sycl_launch(stream, [&](sycl::handler & cgh) {
35373540
sycl::local_accessor<int, 0> src1_row_acc(cgh);
@@ -3575,7 +3578,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
35753578
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
35763579

35773580
{
3578-
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
3581+
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
35793582
sycl::range<3> grid_dims(1, 1, num_src1_rows);
35803583
sycl_launch(stream, [&](sycl::handler & cgh) {
35813584
const char *__restrict dst_contiguous_get =

0 commit comments

Comments
 (0)