File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -404,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
404
404
if (ic0 + j_VKQ >= ne01) {
405
405
return ;
406
406
}
407
- const int j_dst = (ic0 + j_VKQ)*gridDim .y + blockIdx .y ;
408
407
409
408
float KQ_rowsum_j;
410
409
if (std::is_same<KQ_acc_t, float >::value) {
@@ -413,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
413
412
KQ_rowsum_j = __low2float (KQ_rowsum_h2[j0/nwarps]) + __high2float (KQ_rowsum_h2[j0/nwarps]);
414
413
}
415
414
415
+ const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim .y + blockIdx .y ;
416
+
416
417
#pragma unroll
417
418
for (int i0 = 0 ; i0 < D; i0 += warp_size) {
418
419
const int i = i0 + threadIdx .x ;
@@ -423,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
423
424
if (gridDim .y == 1 ) {
424
425
dst_val /= KQ_rowsum_j;
425
426
}
426
- dst[((sequence*ne01 + j_dst)*ne02 + head)* D + tid ] = dst_val;
427
+ dst[j_dst_unrolled* D + i ] = dst_val;
427
428
}
428
429
429
430
if (gridDim .y == 1 || threadIdx .x != 0 ) {
@@ -437,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
437
438
dst_meta_val.x = __low2float (KQ_max_h2[j0/nwarps]);
438
439
}
439
440
dst_meta_val.y = KQ_rowsum_j;
440
- dst_meta[((ic0 + j_VKQ)* gridDim . z + blockIdx . z ) * gridDim . y + blockIdx . y ] = dst_meta_val;
441
+ dst_meta[j_dst_unrolled ] = dst_meta_val;
441
442
}
442
443
#else
443
444
GGML_UNUSED (Q); GGML_UNUSED (K); GGML_UNUSED (V); GGML_UNUSED (mask);
You can’t perform that action at this time.
0 commit comments