Skip to content

Commit 326e4e2

Browse files
CUDA: fix WMMA FA kernel
1 parent 710757f commit 326e4e2

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -404,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
404404
if (ic0 + j_VKQ >= ne01) {
405405
return;
406406
}
407-
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
408407

409408
float KQ_rowsum_j;
410409
if (std::is_same<KQ_acc_t, float>::value) {
@@ -413,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
413412
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
414413
}
415414

415+
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
416+
416417
#pragma unroll
417418
for (int i0 = 0; i0 < D; i0 += warp_size) {
418419
const int i = i0 + threadIdx.x;
@@ -423,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
423424
if (gridDim.y == 1) {
424425
dst_val /= KQ_rowsum_j;
425426
}
426-
dst[((sequence*ne01 + j_dst)*ne02 + head)*D + tid] = dst_val;
427+
dst[j_dst_unrolled*D + i] = dst_val;
427428
}
428429

429430
if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -437,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
437438
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
438439
}
439440
dst_meta_val.y = KQ_rowsum_j;
440-
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
441+
dst_meta[j_dst_unrolled] = dst_meta_val;
441442
}
442443
#else
443444
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);

0 commit comments

Comments
 (0)