Skip to content

Commit 757aa62

Browse files
committed
ggml : fix mamba2 ssm scan when compiled with SVE
1 parent 2fa5f2c commit 757aa62

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7664,6 +7664,37 @@ static void ggml_compute_forward_ssm_scan_f32(
76647664
const float x_dt = x[ii] * dt_soft_plus;
76657665
float sumf = 0.0f;
76667666
#if defined(GGML_SIMD)
7667+
#if defined(__ARM_FEATURE_SVE)
7668+
const int ggml_f32_epr = svcntw();
7669+
const int ggml_f32_step = 1 * ggml_f32_epr;
7670+
7671+
const int np = (nc & ~(ggml_f32_step - 1));
7672+
7673+
GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
7674+
7675+
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
7676+
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
7677+
7678+
for (int i = 0; i < np; i += ggml_f32_step) {
7679+
// TODO: maybe unroll more?
7680+
for (int j = 0; j < 1; j++) {
7681+
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
7682+
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
7683+
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
7684+
7685+
t0 = GGML_F32_VEC_MUL(t0, adA);
7686+
t1 = GGML_F32_VEC_MUL(t1, axdt);
7687+
7688+
t0 = GGML_F32_VEC_ADD(t0, t1);
7689+
7690+
sum = GGML_F32_VEC_FMA(sum, t0, t2);
7691+
7692+
GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
7693+
}
7694+
}
7695+
7696+
sumf = GGML_F32xt_REDUCE_ONE(sum);
7697+
#else
76677698
const int np = (nc & ~(GGML_F32_STEP - 1));
76687699

76697700
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -7694,6 +7725,7 @@ static void ggml_compute_forward_ssm_scan_f32(
76947725

76957726
// reduce sum0..sum3 to sum0
76967727
GGML_F32_VEC_REDUCE(sumf, sum);
7728+
#endif
76977729
#else
76987730
const int np = 0;
76997731
#endif
@@ -7722,7 +7754,7 @@ static void ggml_compute_forward_ssm_scan_f32(
77227754
for (int i1 = 0; i1 < nr; ++i1) {
77237755
const int ii = i1 + h*nr;
77247756
const float x_dt = x[ii] * dt_soft_plus;
7725-
#ifdef __ARM_FEATURE_SVE
7757+
#if defined(__ARM_FEATURE_SVE)
77267758
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
77277759
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
77287760
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;

0 commit comments

Comments
 (0)