Skip to content

Commit e079bff

Browse files
authored
cuda : fix FA Q src index (1 -> 0) (ggml-org#9374)
1 parent 3f7ccfd commit e079bff

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
152152
} \
153153

154154
static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
155-
ggml_tensor * Q = dst->src[1];
155+
ggml_tensor * Q = dst->src[0];
156156
ggml_tensor * K = dst->src[1];
157157
ggml_tensor * V = dst->src[2];
158158

@@ -227,7 +227,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
227227
} \
228228

229229
static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
230-
ggml_tensor * Q = dst->src[1];
230+
ggml_tensor * Q = dst->src[0];
231231
ggml_tensor * K = dst->src[1];
232232
ggml_tensor * V = dst->src[2];
233233

0 commit comments

Comments
 (0)