@@ -149,6 +149,8 @@ static __global__ void flash_attn_vec_ext_f16(
149
149
VKQ += V_k*KQ2[k0/2 ];
150
150
}
151
151
}
152
+
153
+ __syncthreads ();
152
154
}
153
155
154
156
if (tid >= D) {
@@ -547,7 +549,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
547
549
dst_tmp_meta.alloc (parallel_blocks*ggml_nrows (KQV));
548
550
}
549
551
550
- constexpr int nwarps = ((D) + WARP_SIZE - 1 ) / WARP_SIZE;
552
+ constexpr int nwarps = (D + WARP_SIZE - 1 ) / WARP_SIZE;
551
553
constexpr dim3 block_dim (WARP_SIZE, nwarps, 1 );
552
554
const dim3 blocks_num (parallel_blocks*Q->ne [1 ], Q->ne [2 ], Q->ne [3 ]);
553
555
const int shmem = 0 ;
@@ -561,7 +563,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
561
563
(const char *) K->data ,
562
564
(const char *) V->data ,
563
565
mask ? ((const char *) mask->data ) : nullptr ,
564
- ( parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr , dst_tmp_meta.ptr ,
566
+ parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr , dst_tmp_meta.ptr ,
565
567
scale,
566
568
Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
567
569
K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
@@ -572,7 +574,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
572
574
);
573
575
CUDA_CHECK (cudaGetLastError ());
574
576
575
- if (( parallel_blocks) == 1 ) {
577
+ if (parallel_blocks == 1 ) {
576
578
return ;
577
579
}
578
580
0 commit comments