@@ -840,18 +840,34 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*

     float scale = (1.0f / sqrt((float)d_head));

-    // if (flash_attn) {
-    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    // }
+    int kv_pad = 0;
+    // if (flash_attn) {
+    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    // }

     // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));

     bool can_use_flash_attn = true;
+    can_use_flash_attn = can_use_flash_attn && (
+        d_head == 64 ||
+        d_head == 80 ||
+        d_head == 96 ||
+        d_head == 112 ||
+        d_head == 128 ||
+        d_head == 256
+    );
+#if 0
     can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
-    can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0;  // double check
-
-    // cuda max d_head seems to be 256, cpu does seem to work with 512
-    can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
+#else
+    if (can_use_flash_attn && L_k % 256 != 0) {
+        // TODO(Green-Sky): might be worth just padding by default
+        if (L_k == 77 || L_k == 4208 || L_k == 3952) {
+            kv_pad = GGML_PAD(L_k, 256) - L_k;
+        } else {
+            can_use_flash_attn = false;
+        }
+    }
+#endif

     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
@@ -864,11 +880,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        // LOG_DEBUG("using flash attention");
+        // LOG_DEBUG(" uses flash attention");
+        if (kv_pad != 0) {
+            // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
+            k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+        }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);

         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
+        if (kv_pad != 0) {
+            v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
+        }
         v = ggml_cast(ctx, v, GGML_TYPE_F16);

         if (mask != nullptr) {
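
For reference, a minimal standalone sketch (not part of the commit) of the padding arithmetic introduced above. GGML_PAD(x, n) rounds x up to the next multiple of n (the macro is repeated here so the snippet compiles on its own), so kv_pad is the number of extra K/V rows needed to reach the 256-row granularity the flash-attention path expects. The lengths 77, 4208 and 3952 are simply the L_k values special-cased in the hunk above.

// pad_sketch.cpp -- illustration only, assuming the GGML_PAD definition from ggml.h
#include <cstdio>

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    const int lengths[] = {77, 4208, 3952};  // L_k values special-cased in the diff
    for (int L_k : lengths) {
        int padded = GGML_PAD(L_k, 256);     // round up to a multiple of 256
        int kv_pad = padded - L_k;           // rows of padding added to k and v
        // e.g. L_k=77 -> padded=256, kv_pad=179; L_k=4208 -> 4352, 144; L_k=3952 -> 4096, 144
        std::printf("L_k=%4d -> padded=%4d, kv_pad=%3d\n", L_k, padded, kv_pad);
    }
    return 0;
}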