Skip to content

Commit 6a3b842

Browse files
fix flash_attn_vec_f16 race condition
1 parent 34f93bb commit 6a3b842

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

ggml-cuda/fattn.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ static __global__ void flash_attn_vec_ext_f16(
149149
VKQ += V_k*KQ2[k0/2];
150150
}
151151
}
152+
153+
__syncthreads();
152154
}
153155

154156
if (tid >= D) {
@@ -547,7 +549,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
547549
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
548550
}
549551

550-
constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE;
552+
constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
551553
constexpr dim3 block_dim(WARP_SIZE, nwarps, 1);
552554
const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
553555
const int shmem = 0;
@@ -561,7 +563,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
561563
(const char *) K->data,
562564
(const char *) V->data,
563565
mask ? ((const char *) mask->data) : nullptr,
564-
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
566+
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
565567
scale,
566568
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
567569
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -572,7 +574,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
572574
);
573575
CUDA_CHECK(cudaGetLastError());
574576

575-
if ((parallel_blocks) == 1) {
577+
if (parallel_blocks == 1) {
576578
return;
577579
}
578580

0 commit comments

Comments
 (0)