Commit 259b135

Merge pull request #9 from ZelinMa557/optim2 (Optim2)

2 parents b73c20c + 921b47c, commit 259b135

2 files changed: +149, -106 lines

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 1 deletion

@@ -2804,7 +2804,11 @@ struct ggml_cplan ggml_graph_plan(
                     const int64_t ne10 = node->src[1]->ne[0]; // DK
                     const int64_t ne20 = node->src[2]->ne[0]; // DV
 
-                    cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
+                    const int64_t ne02 = node->src[0]->ne[2]; // n_head
+                    const int64_t ne12 = node->src[1]->ne[2]; // n_kv_head
+                    const int64_t n_gqa = ne02/ne12;
+
+                    cur = sizeof(float)*n_gqa*(1*ne10 + 2*ne20)*n_tasks; // ngqa * (1x head size K + 2x head size V) (per thread)
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
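Note: the workspace estimate above grows by the GQA factor because each thread now carries n_gqa sets of per-head accumulators instead of one. A minimal sketch of the same arithmetic, as a hypothetical standalone helper (not part of the patch):

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    // Mirrors the patched estimate: cur = sizeof(float)*n_gqa*(1*DK + 2*DV)*n_tasks,
    // where DK/DV are the K/V head sizes and n_gqa = n_head / n_kv_head.
    static size_t flash_attn_ext_wsize(int64_t DK, int64_t DV, int64_t n_gqa, int n_tasks) {
        return sizeof(float) * (size_t)(n_gqa * (1*DK + 2*DV)) * (size_t)n_tasks;
    }

    int main(void) {
        // e.g. 128-dim K/V heads, 32 query heads over 8 KV heads (n_gqa = 4), 8 threads
        printf("%zu bytes\n", flash_attn_ext_wsize(128, 128, 32/8, 8));
        return 0;
    }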

ggml/src/ggml-cpu/ops.cpp

Lines changed: 144 additions & 105 deletions
@@ -6944,7 +6944,82 @@ void ggml_compute_forward_argsort(
 }
 
 // ggml_compute_forward_flash_attn_ext
+static inline void ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+        const ggml_fp16_t *Q,
+        const char *K,
+        const char *V,
+        const int64_t DK,
+        const int64_t DV,
+        const float mask_value,
+        const float scale,
+        const float logit_softcap,
+        const enum ggml_type v_type,
+        ggml_vec_dot_t const kq_vec_dot,
+        ggml_to_float_t const v_to_float,
+        ggml_fp16_t *VKQ16,
+        float *VKQ32,
+        float *V32,
+        float *sum,
+        float *max_kq_value) {
+    float s; // KQ value
+    kq_vec_dot(DK, &s, 0, K, 0, Q, 0, 1);
+
+    s = s*scale; // scale KQ value
+
+    if (logit_softcap != 0.0f) {
+        s = logit_softcap*tanhf(s);
+    }
+    s += mask_value; // apply mask
+    float M = *max_kq_value;
+    const float Mold = M;
+
+    float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+    float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+    if (v_type == GGML_TYPE_F16) {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f16(DV, VKQ16, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
+
+        // V += v*expf(s - M)
+        ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) V, vs);
+    } else {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
 
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f32(DV, VKQ32, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
+
+        // V += v*expf(s - M)
+        if (v_to_float) {
+            v_to_float(V, V32, DV);
+            ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+        } else {
+            // V is F32
+            ggml_vec_mad_f32(DV, VKQ32, (const float *) V, vs);
+        }
+    }
+    float S = *sum;
+    S = S*ms + vs; // scale and increment sum with partial sum
+    *sum = S;
+    *max_kq_value = M;
+}
+
+#define GGML_FLASH_ATTN_EXT_MAX_GQA 16
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
         const ggml_tensor * q,
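The helper added above performs one step of the online softmax for a single (Q head, K/V column) pair: it updates the running maximum M, rescales the existing accumulator when M changes, adds the new weighted V row, and keeps the running denominator S in sync. A scalar sketch of the same recurrence with plain float vectors (names here are illustrative, not taken from the patch):

    #include <math.h>
    #include <stdint.h>

    // One online-softmax step for score s and value row v[DV]
    // (ref: https://arxiv.org/pdf/2112.05682.pdf).
    static void online_softmax_step(float s, const float *v, int64_t DV,
                                    float *acc, float *S, float *M) {
        float ms = 1.0f;       // rescale factor for the terms accumulated so far
        float vs = 1.0f;       // weight of the new term, expf(s - M)
        if (s > *M) {
            ms = expf(*M - s); // a new maximum: shrink everything accumulated so far
            *M = s;
            for (int64_t d = 0; d < DV; ++d) acc[d] *= ms;
        } else {
            vs = expf(s - *M);
        }
        for (int64_t d = 0; d < DV; ++d) acc[d] += vs * v[d];
        *S = *S * ms + vs;     // softmax denominator, kept consistent with acc
    }

Because the state (S, M, accumulator) lives behind pointers, the caller can keep one such state per query head and reuse the same K/V row across all heads of a GQA group, which is what the rewritten loop below does.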
@@ -6997,16 +7072,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int64_t rv3 = neq3/nev3;
 
     // parallelize by q rows using ggml_vec_dot_f32
+    const uint32_t n_head = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
+    const uint32_t n_kv_head = nek2;
+    const int n_gqa = n_head / n_kv_head;
+    GGML_ASSERT(n_gqa <= GGML_FLASH_ATTN_EXT_MAX_GQA);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // total groups in q
+    const int ng = neq1*neq2*neq3/n_gqa;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    // groups per thread
+    const int dg = (ng + nth - 1)/nth;
+
+    // group range for this thread
+    const int ig0 = dg*ith;
+    const int ig1 = MIN(ig0 + dg, ng);
 
     float scale = 1.0f;
     float max_bias = 0.0f;
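With this hunk the unit of parallel work changes from a single Q row to a "group": one query position paired with one KV head, covering the n_gqa query heads that share that KV head. A small sketch of the index mapping the loop uses further down (assuming a single batch in dim 3, i.e. iq3 == 0, as the patched code does):

    // ig in [ig0, ig1) enumerates (query position, KV head) pairs:
    //   group_index = ig % n_kv_head -> which KV head the group attends with
    //   batch_index = ig / n_kv_head -> which query position (Q row)
    // The group covers query heads [group_index*n_gqa, (group_index+1)*n_gqa).
    static void group_to_indices(int ig, int n_kv_head, int n_gqa,
                                 int *iq1, int *iq2_first) {
        *iq1       = ig / n_kv_head;           // query position
        *iq2_first = (ig % n_kv_head) * n_gqa; // first query head of the group
    }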
@@ -7020,9 +7101,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         scale /= logit_softcap;
     }
 
-    const uint32_t n_head = neq2;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
-
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
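This hunk only hoists n_head and n_head_log2 earlier (they are now needed to compute n_gqa before the group loop is set up); the m0/m1 context lines are the ALiBi-style slope bases that the per-head slope[i_gqa] values below are built from. For reference, a self-contained sketch of that slope formula (my restatement, not code from the patch):

    #include <math.h>
    #include <stdint.h>

    // slope(h) = m0^(h+1) if h < n_head_log2, else m1^(2*(h - n_head_log2) + 1),
    // or 1.0 when max_bias == 0.
    static float head_slope(uint32_t h, uint32_t n_head, float max_bias) {
        if (max_bias <= 0.0f) return 1.0f;
        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    }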

@@ -7034,28 +7112,42 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
     GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
-    // loop over n_batch and n_head
-    for (int ir = ir0; ir < ir1; ++ir) {
+    float         S[GGML_FLASH_ATTN_EXT_MAX_GQA];     // sum
+    float         M[GGML_FLASH_ATTN_EXT_MAX_GQA];     // maximum KQ value
+    float       * VKQ32[GGML_FLASH_ATTN_EXT_MAX_GQA]; // FP32 VKQ accumulator
+    float       * V32[GGML_FLASH_ATTN_EXT_MAX_GQA];   // (temporary) FP32 V buffer
+    ggml_fp16_t * VKQ16[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP16 VKQ accumulator
+    ggml_fp16_t * Q_q[GGML_FLASH_ATTN_EXT_MAX_GQA];   // (temporary) buffer for Q converted to quantized/FP16
+    float         slope[GGML_FLASH_ATTN_EXT_MAX_GQA];
+
+    for (int ig = ig0; ig < ig1; ++ig) {
+        const int group_index = ig % n_kv_head;
+        const int batch_index = ig / n_kv_head;
         // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        const uint32_t h = iq2; // head index
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+        const int iq3 = 0;
+        const int iq2 = group_index * n_gqa; // start head index
+        const int iq1 = batch_index;
+
+        const int single_buffer_size = 1*DK + 2*DV;
+        for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+            S[i_gqa] = 0.0f;
+            M[i_gqa] = -INFINITY;
+            VKQ32[i_gqa] = (float *) params->wdata + ith*(single_buffer_size*n_gqa + CACHE_LINE_SIZE_F32) + single_buffer_size*i_gqa;
+            V32  [i_gqa] = (VKQ32[i_gqa] + 1*DV);
+            VKQ16[i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 1*DV);
+            Q_q  [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 2*DV);
 
-        float S = 0.0f;      // sum
-        float M = -INFINITY; // maximum KQ value
+            if (v->type == GGML_TYPE_F16) {
+                memset(VKQ16[i_gqa], 0, DV*sizeof(ggml_fp16_t));
+            } else {
+                memset(VKQ32[i_gqa], 0, DV*sizeof(float));
+            }
 
-        float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float * V32   = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
+            const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + (iq2 + i_gqa)*nbq2 + iq3*nbq3));
+            q_to_vec_dot(pq, Q_q[i_gqa], DK);
 
-        if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
-        } else {
-            memset(VKQ32, 0, DV*sizeof(float));
+            const uint32_t h = iq2 + i_gqa;
+            slope[i_gqa] = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
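Each thread's slice of params->wdata is now split into n_gqa fixed-size blocks of (1*DK + 2*DV) floats, one per query head of the group: the first DV floats hold the FP32 VKQ accumulator, the next DV floats are shared by the temporary V32 buffer and the FP16 VKQ16 accumulator (they alias the same bytes), and the last DK floats are reserved for Q converted to the K vec-dot type. A sketch of that addressing with hypothetical names (the real code indexes wdata inline, as shown above):

    #include <stdint.h>

    // Start of GQA slot i_gqa inside thread ith's scratch region (offsets in floats).
    static float * gqa_slot(float *wdata, int ith, int i_gqa,
                            int64_t DK, int64_t DV, int n_gqa, int cache_line_f32) {
        const int64_t single_buffer_size = 1*DK + 2*DV;
        return wdata
             + (int64_t) ith * (single_buffer_size*n_gqa + cache_line_f32)
             + single_buffer_size*i_gqa;
    }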
@@ -7068,99 +7160,46 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
-        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, DK);
-
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
-            if (mv == -INFINITY) {
+            const float mp_value_base = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mp_value_base == -INFINITY) {
                 continue;
             }
-
-            float s; // KQ value
-
+            const char * v_data = (const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3);
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
-
-            s = s*scale; // scale KQ value
-
-            if (logit_softcap != 0.0f) {
-                s = logit_softcap*tanhf(s);
+            for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+                const float mv = mp_value_base * slope[i_gqa];
+                ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+                    Q_q[i_gqa], k_data, v_data, DK, DV, mv, scale, logit_softcap, v->type,
+                    kq_vec_dot, v_to_float, VKQ16[i_gqa], VKQ32[i_gqa], V32[i_gqa], S+i_gqa, M+i_gqa);
             }
+        }
 
-            s += mv; // apply mask
-
-            const float Mold = M;
-
-            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
-            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
-
-            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-
+        for (int i = 0; i < n_gqa; ++i) {
             if (v->type == GGML_TYPE_F16) {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(DV, VKQ16, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
-            } else {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(DV, VKQ32, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                if (v_to_float) {
-                    v_to_float(v_data, V32, DV);
-                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
-                } else {
-                    // V is F32
-                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+                for (int64_t d = 0; d < DV; ++d) {
+                    VKQ32[i][d] = GGML_FP16_TO_FP32(VKQ16[i][d]);
                 }
             }
 
-            S = S*ms + vs; // scale and increment sum with partial sum
-        }
+            // V /= S
+            const float S_inv = 1.0f/S[i];
+            ggml_vec_scale_f32(DV, VKQ32[i], S_inv);
 
-        if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < DV; ++d) {
-                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
-            }
-        }
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2 + i;
+            const int i3 = iq3;
 
-        // V /= S
-        const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(DV, VKQ32, S_inv);
+            // original
+            //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
-
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
-
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32[i], nb1);
+        }
     }
 }
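The restructured loop above is the core of the optimization: for every KV column ic, k_data and v_data are computed once and then consumed by all n_gqa query heads of the group, so each K/V row is read once per group rather than once per query head, and the mask value is converted from FP16 only once. A condensed, self-contained reference of the new loop ordering for one group (hypothetical plain-float version; the real code works on ggml tensors, FP16/quantized types and the vec_dot/mad kernels):

    #include <math.h>
    #include <stdint.h>
    #include <string.h>

    // Q: n_gqa rows of DK, K: n_kv rows of DK, V: n_kv rows of DV, out: n_gqa rows of DV.
    static void gqa_group_attention(const float *Q, const float *K, const float *V, float *out,
                                    int n_gqa, int64_t n_kv, int64_t DK, int64_t DV, float scale) {
        float S[16], M[16]; // running sum / running max per head (n_gqa <= 16, as in the patch)
        for (int i = 0; i < n_gqa; ++i) {
            S[i] = 0.0f;
            M[i] = -INFINITY;
            memset(out + i*DV, 0, DV*sizeof(float));
        }
        for (int64_t ic = 0; ic < n_kv; ++ic) {     // outer loop: one pass over the KV cache
            const float *k = K + ic*DK;
            const float *v = V + ic*DV;
            for (int i = 0; i < n_gqa; ++i) {       // inner loop: every head reuses k and v
                float s = 0.0f;
                for (int64_t d = 0; d < DK; ++d) s += k[d]*Q[i*DK + d];
                s *= scale;
                float ms = 1.0f, vs = 1.0f;         // online softmax update, as in one_QKV
                if (s > M[i]) {
                    ms = expf(M[i] - s);
                    M[i] = s;
                    for (int64_t d = 0; d < DV; ++d) out[i*DV + d] *= ms;
                } else {
                    vs = expf(s - M[i]);
                }
                for (int64_t d = 0; d < DV; ++d) out[i*DV + d] += vs*v[d];
                S[i] = S[i]*ms + vs;
            }
        }
        for (int i = 0; i < n_gqa; ++i) {           // finalize: V /= S per head
            for (int64_t d = 0; d < DV; ++d) out[i*DV + d] /= S[i];
        }
    }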
