@@ -6944,7 +6944,82 @@ void ggml_compute_forward_argsort(
 }
 
 // ggml_compute_forward_flash_attn_ext
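+// Process one (K, V) column against a single query head: compute the scaled,
+// soft-capped, masked KQ score, then fold it into the running online-softmax
+// state (*sum, *max_kq_value) and the VKQ accumulator (FP16 or FP32 path).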
+static inline void ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+        const ggml_fp16_t * Q,
+        const char * K,
+        const char * V,
+        const int64_t DK,
+        const int64_t DV,
+        const float mask_value,
+        const float scale,
+        const float logit_softcap,
+        const enum ggml_type v_type,
+        ggml_vec_dot_t const kq_vec_dot,
+        ggml_to_float_t const v_to_float,
+        ggml_fp16_t * VKQ16,
+        float * VKQ32,
+        float * V32,
+        float * sum,
+        float * max_kq_value) {
+    float s; // KQ value
+    kq_vec_dot(DK, &s, 0, K, 0, Q, 0, 1);
+
+    s = s*scale; // scale KQ value
+
+    if (logit_softcap != 0.0f) {
+        s = logit_softcap*tanhf(s);
+    }
+    s += mask_value; // apply mask
+    float M = *max_kq_value;
+    const float Mold = M;
+
+    float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+    float vs = 1.0f; // post-softmax KQ value, expf(s - M)
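+    // online-softmax invariant: VKQ accumulates sum_j expf(s_j - M)*V_j and
+    // *sum accumulates sum_j expf(s_j - M), with M the running maximum score;
+    // a new maximum rescales both by ms = expf(Mold - M) to stay consistent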
+
+    if (v_type == GGML_TYPE_F16) {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f16(DV, VKQ16, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
+
+        // V += v*expf(s - M)
+        ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) V, vs);
+    } else {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f32(DV, VKQ32, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
+
+        // V += v*expf(s - M)
+        if (v_to_float) {
+            v_to_float(V, V32, DV);
+            ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+        } else {
+            // V is F32
+            ggml_vec_mad_f32(DV, VKQ32, (const float *) V, vs);
+        }
+    }
+    float S = *sum;
+    S = S*ms + vs; // scale and increment sum with partial sum
+    *sum = S;
+    *max_kq_value = M;
+}
+
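+// upper bound on query heads per KV head processed together; sizes the
+// per-group accumulator arrays declared below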
+#define GGML_FLASH_ATTN_EXT_MAX_GQA 16
 
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
         const ggml_tensor * q,
@@ -6997,16 +7072,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int64_t rv3 = neq3/nev3;
 
     // parallelize by q rows using ggml_vec_dot_f32
+    const uint32_t n_head      = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
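+    // grouped-query attention: n_gqa query heads share each KV head, so one
+    // K/V row can be loaded once and reused across the whole group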
+    const uint32_t n_kv_head = nek2;
+    const int n_gqa = n_head / n_kv_head;
+    GGML_ASSERT(n_gqa <= GGML_FLASH_ATTN_EXT_MAX_GQA);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // total groups in q
+    const int ng = neq1*neq2*neq3/n_gqa;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    // groups per thread
+    const int dg = (ng + nth - 1)/nth;
+
+    // group range for this thread
+    const int ig0 = dg*ith;
+    const int ig1 = MIN(ig0 + dg, ng);
 
     float scale    = 1.0f;
     float max_bias = 0.0f;
@@ -7020,9 +7101,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         scale /= logit_softcap;
     }
 
-    const uint32_t n_head      = neq2;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
-
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
@@ -7034,28 +7112,42 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
     GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
-    // loop over n_batch and n_head
-    for (int ir = ir0; ir < ir1; ++ir) {
+    float S[GGML_FLASH_ATTN_EXT_MAX_GQA]; // sum
+    float M[GGML_FLASH_ATTN_EXT_MAX_GQA]; // maximum KQ value
+    float * VKQ32[GGML_FLASH_ATTN_EXT_MAX_GQA]; // FP32 VKQ accumulator
+    float * V32[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP32 V buffer
+    ggml_fp16_t * VKQ16[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP16 VKQ accumulator
+    ggml_fp16_t * Q_q[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) buffer for Q converted to quantized/FP16
+    float slope[GGML_FLASH_ATTN_EXT_MAX_GQA];
+
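+    // each group ig selects one (KV head, batch row) pair; its n_gqa query
+    // heads are processed together below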
+    for (int ig = ig0; ig < ig1; ++ig) {
+        const int group_index = ig % n_kv_head;
+        const int batch_index = ig / n_kv_head;
         // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        const uint32_t h = iq2; // head index
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+        const int iq3 = 0;
+        const int iq2 = group_index * n_gqa; // start head index
+        const int iq1 = batch_index;
+
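+        // per-head workspace layout in wdata:
+        // [VKQ32: DV floats][V32 / VKQ16 (aliased): DV floats][Q_q: DK]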
+        const int single_buffer_size = 1*DK + 2*DV;
+        for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+            S[i_gqa] = 0.0f;
+            M[i_gqa] = -INFINITY;
+            VKQ32[i_gqa] = (float *) params->wdata + ith*(single_buffer_size*n_gqa + CACHE_LINE_SIZE_F32) + single_buffer_size*i_gqa;
+            V32  [i_gqa] = (VKQ32[i_gqa] + 1*DV);
+            VKQ16[i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 1*DV);
+            Q_q  [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 2*DV);
 
-        float S = 0.0f;      // sum
-        float M = -INFINITY; // maximum KQ value
+            if (v->type == GGML_TYPE_F16) {
+                memset(VKQ16[i_gqa], 0, DV*sizeof(ggml_fp16_t));
+            } else {
+                memset(VKQ32[i_gqa], 0, DV*sizeof(float));
+            }
 
-        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
+            const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + (iq2 + i_gqa)*nbq2 + iq3*nbq3));
+            q_to_vec_dot(pq, Q_q[i_gqa], DK);
 
-        if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
-        } else {
-            memset(VKQ32, 0, DV*sizeof(float));
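+            // per-head bias slope (ALiBi-style); 1.0f when max_bias is disabled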
+            const uint32_t h = iq2 + i_gqa;
+            slope[i_gqa] = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -7068,99 +7160,46 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
-        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, DK);
-
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
-            if (mv == -INFINITY) {
+            const float mp_value_base = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mp_value_base == -INFINITY) {
                 continue;
             }
-
-            float s; // KQ value
-
+            const char * v_data = (const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3);
             const char * k_data = (const char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
-
-            s = s*scale; // scale KQ value
-
-            if (logit_softcap != 0.0f) {
-                s = logit_softcap*tanhf(s);
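+            // one K/V fetch serves all n_gqa query heads in the group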
+            for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+                const float mv = mp_value_base * slope[i_gqa];
+                ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+                    Q_q[i_gqa], k_data, v_data, DK, DV, mv, scale, logit_softcap, v->type,
+                    kq_vec_dot, v_to_float, VKQ16[i_gqa], VKQ32[i_gqa], V32[i_gqa], S+i_gqa, M+i_gqa);
             }
+        }
 
-            s += mv; // apply mask
-
-            const float Mold = M;
-
-            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
-            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
-
-            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-
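+        // finalize: normalize each head's accumulator by its softmax sum and
+        // write the result to the permuted destination row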
+        for (int i = 0; i < n_gqa; ++i) {
             if (v->type == GGML_TYPE_F16) {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(DV, VKQ16, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
-            } else {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(DV, VKQ32, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                if (v_to_float) {
-                    v_to_float(v_data, V32, DV);
-                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
-                } else {
-                    // V is F32
-                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+                for (int64_t d = 0; d < DV; ++d) {
+                    VKQ32[i][d] = GGML_FP16_TO_FP32(VKQ16[i][d]);
                 }
             }
 
-            S = S*ms + vs; // scale and increment sum with partial sum
-        }
+            // V /= S
+            const float S_inv = 1.0f/S[i];
+            ggml_vec_scale_f32(DV, VKQ32[i], S_inv);
 
-        if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < DV; ++d) {
-                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
-            }
-        }
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2 + i;
+            const int i3 = iq3;
 
-        // V /= S
-        const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(DV, VKQ32, S_inv);
+            // original
+            // memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
-
-        // original
-        // memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
-
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32[i], nb1);
+        }
     }
 }