Skip to content

Commit 4f7d698

Browse files
committed
format
Signed-off-by: ZelinMa557 <3388706467@qq.com>
1 parent 54b99d2 commit 4f7d698

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -7046,16 +7046,16 @@ static void ggml_compute_forward_flash_attn_ext_f16(
70467046
V32 [i_gqa] = (VKQ32[i_gqa] + 1*DV);
70477047
VKQ16 [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 1*DV);
70487048
Q_q [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 2*DV);
7049-
7049+
70507050
if (v->type == GGML_TYPE_F16) {
70517051
memset(VKQ16[i_gqa], 0, DV*sizeof(ggml_fp16_t));
70527052
} else {
70537053
memset(VKQ32[i_gqa], 0, DV*sizeof(float));
70547054
}
7055-
7055+
70567056
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + (iq2 + i_gqa)*nbq2 + iq3*nbq3));
70577057
q_to_vec_dot(pq, Q_q[i_gqa], DK);
7058-
7058+
70597059
const uint32_t h = iq2 + i_gqa;
70607060
slope[i_gqa] = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
70617061
}
@@ -7083,7 +7083,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
70837083
for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
70847084
const float mv = mp_value_base * slope[i_gqa];
70857085
ggml_compute_forward_flash_attn_ext_f16_one_QKV(
7086-
Q_q[i_gqa], k_data, v_data, DK, DV, mv, scale, logit_softcap, v->type,
7086+
Q_q[i_gqa], k_data, v_data, DK, DV, mv, scale, logit_softcap, v->type,
70877087
kq_vec_dot, v_to_float, VKQ16[i_gqa], VKQ32[i_gqa], V32[i_gqa], S+i_gqa, M+i_gqa);
70887088
}
70897089
}
@@ -7094,19 +7094,19 @@ static void ggml_compute_forward_flash_attn_ext_f16(
70947094
VKQ32[i][d] = GGML_FP16_TO_FP32(VKQ16[i][d]);
70957095
}
70967096
}
7097-
7097+
70987098
// V /= S
70997099
const float S_inv = 1.0f/S[i];
71007100
ggml_vec_scale_f32(DV, VKQ32[i], S_inv);
7101-
7101+
71027102
// dst indices
71037103
const int i1 = iq1;
71047104
const int i2 = iq2 + i;
71057105
const int i3 = iq3;
7106-
7106+
71077107
// original
71087108
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
7109-
7109+
71107110
// permute(0, 2, 1, 3)
71117111
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32[i], nb1);
71127112
}

0 commit comments

Comments
 (0)