Commit ab835f7

fix: correct head dim check and L_k padding of flash attention (#736)
1 parent 26f3f61 commit ab835f7

File tree

1 file changed: +31 -8 lines changed

ggml_extend.hpp

Lines changed: 31 additions & 8 deletions
@@ -840,18 +840,34 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     float scale = (1.0f / sqrt((float)d_head));
 
-    // if (flash_attn) {
-    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    // }
+    int kv_pad = 0;
+    //if (flash_attn) {
+    //    LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    //}
     // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
+    can_use_flash_attn = can_use_flash_attn && (
+        d_head == 64 ||
+        d_head == 80 ||
+        d_head == 96 ||
+        d_head == 112 ||
+        d_head == 128 ||
+        d_head == 256
+    );
+#if 0
     can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
-    can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0;  // double check
-
-    // cuda max d_head seems to be 256, cpu does seem to work with 512
-    can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
+#else
+    if (can_use_flash_attn && L_k % 256 != 0) {
+        // TODO(Green-Sky): might be worth just padding by default
+        if (L_k == 77 || L_k == 4208 || L_k == 3952) {
+            kv_pad = GGML_PAD(L_k, 256) - L_k;
+        } else {
+            can_use_flash_attn = false;
+        }
+    }
+#endif
 
     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
@@ -864,11 +880,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        // LOG_DEBUG("using flash attention");
+        //LOG_DEBUG(" uses flash attention");
+        if (kv_pad != 0) {
+            //LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
+            k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+        }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
+        if (kv_pad != 0) {
+            v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
+        }
         v = ggml_cast(ctx, v, GGML_TYPE_F16);
 
         if (mask != nullptr) {
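
Note on the change (not part of the commit): the old d_head % 64 == 0 && d_head <= 256 check is replaced with an explicit whitelist of head sizes, presumably the ones the flash attention kernels actually handle, and for a few known key/value lengths that are not multiples of 256 the commit pads k and v up to the next multiple instead of falling back. The sketch below only illustrates the kv_pad arithmetic; pad_to() is a hypothetical stand-in for the GGML_PAD macro, which rounds its first argument up to a multiple of the second.

#include <cstdio>

// Hypothetical stand-in for GGML_PAD(x, n): round x up to the next multiple of n.
static int pad_to(int x, int n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    // The three sequence lengths the commit special-cases
    // (77 being the familiar CLIP text context length).
    const int lengths[] = {77, 4208, 3952};
    for (int L_k : lengths) {
        int kv_pad = pad_to(L_k, 256) - L_k;  // extra rows ggml_pad appends to k and v
        printf("L_k = %4d -> padded L_k = %4d, kv_pad = %3d\n",
               L_k, pad_to(L_k, 256), kv_pad);
    }
    // Output:
    // L_k =   77 -> padded L_k =  256, kv_pad = 179
    // L_k = 4208 -> padded L_k = 4352, kv_pad = 144
    // L_k = 3952 -> padded L_k = 4096, kv_pad = 144
    return 0;
}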

0 commit comments
