@@ -7046,16 +7046,16 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7046
7046
V32 [i_gqa] = (VKQ32[i_gqa] + 1 *DV);
7047
7047
VKQ16 [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 1 *DV);
7048
7048
Q_q [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 2 *DV);
7049
-
7049
+
7050
7050
if (v->type == GGML_TYPE_F16) {
7051
7051
memset (VKQ16[i_gqa], 0 , DV*sizeof (ggml_fp16_t ));
7052
7052
} else {
7053
7053
memset (VKQ32[i_gqa], 0 , DV*sizeof (float ));
7054
7054
}
7055
-
7055
+
7056
7056
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + (iq2 + i_gqa)*nbq2 + iq3*nbq3));
7057
7057
q_to_vec_dot (pq, Q_q[i_gqa], DK);
7058
-
7058
+
7059
7059
const uint32_t h = iq2 + i_gqa;
7060
7060
slope[i_gqa] = (max_bias > 0 .0f ) ? h < n_head_log2 ? powf (m0, h + 1 ) : powf (m1, 2 *(h - n_head_log2) + 1 ) : 1 .0f ;
7061
7061
}
@@ -7083,7 +7083,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7083
7083
for (int i_gqa = 0 ; i_gqa < n_gqa; ++i_gqa) {
7084
7084
const float mv = mp_value_base * slope[i_gqa];
7085
7085
ggml_compute_forward_flash_attn_ext_f16_one_QKV (
7086
- Q_q[i_gqa], k_data, v_data, DK, DV, mv, scale, logit_softcap, v->type ,
7086
+ Q_q[i_gqa], k_data, v_data, DK, DV, mv, scale, logit_softcap, v->type ,
7087
7087
kq_vec_dot, v_to_float, VKQ16[i_gqa], VKQ32[i_gqa], V32[i_gqa], S+i_gqa, M+i_gqa);
7088
7088
}
7089
7089
}
@@ -7094,19 +7094,19 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7094
7094
VKQ32[i][d] = GGML_FP16_TO_FP32 (VKQ16[i][d]);
7095
7095
}
7096
7096
}
7097
-
7097
+
7098
7098
// V /= S
7099
7099
const float S_inv = 1 .0f /S[i];
7100
7100
ggml_vec_scale_f32 (DV, VKQ32[i], S_inv);
7101
-
7101
+
7102
7102
// dst indices
7103
7103
const int i1 = iq1;
7104
7104
const int i2 = iq2 + i;
7105
7105
const int i3 = iq3;
7106
-
7106
+
7107
7107
// original
7108
7108
// memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
7109
-
7109
+
7110
7110
// permute(0, 2, 1, 3)
7111
7111
memcpy ((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32[i], nb1);
7112
7112
}
0 commit comments