Skip to content

Commit 063d99a

Browse files
authored
[SYCL] fix scratch size of softmax (#8642)
1 parent 081fe43 commit 063d99a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

ggml/src/ggml-sycl/softmax.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
152152

153153
const sycl::range<3> block_dims(1, 1, nth);
154154
const sycl::range<3> block_nums(1, 1, nrows_x);
155-
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
155+
const size_t n_val_tmp = nth / WARP_SIZE;
156+
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
156157

157158
const uint32_t n_head_kv = nrows_x/nrows_y;
158159
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

0 commit comments

Comments
 (0)