@@ -7664,6 +7664,37 @@ static void ggml_compute_forward_ssm_scan_f32(
7664
7664
const float x_dt = x[ii] * dt_soft_plus;
7665
7665
float sumf = 0 .0f ;
7666
7666
#if defined(GGML_SIMD)
7667
+ #if defined(__ARM_FEATURE_SVE)
7668
+ const int ggml_f32_epr = svcntw ();
7669
+ const int ggml_f32_step = 1 * ggml_f32_epr;
7670
+
7671
+ const int np = (nc & ~(ggml_f32_step - 1 ));
7672
+
7673
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
7674
+
7675
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1 (dA);
7676
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1 (x_dt);
7677
+
7678
+ for (int i = 0 ; i < np; i += ggml_f32_step) {
7679
+ // TODO: maybe unroll more?
7680
+ for (int j = 0 ; j < 1 ; j++) {
7681
+ GGML_F32_VEC t0 = GGML_F32_VEC_LOAD (s0 + i + j*ggml_f32_epr + ii*nc);
7682
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD (B + i + j*ggml_f32_epr + (h & (ng - 1 ))*nc);
7683
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD (C + i + j*ggml_f32_epr + (h & (ng - 1 ))*nc);
7684
+
7685
+ t0 = GGML_F32_VEC_MUL (t0, adA);
7686
+ t1 = GGML_F32_VEC_MUL (t1, axdt);
7687
+
7688
+ t0 = GGML_F32_VEC_ADD (t0, t1);
7689
+
7690
+ sum = GGML_F32_VEC_FMA (sum, t0, t2);
7691
+
7692
+ GGML_F32_VEC_STORE (s + i + j*ggml_f32_epr + ii*nc, t0);
7693
+ }
7694
+ }
7695
+
7696
+ sumf = GGML_F32xt_REDUCE_ONE (sum);
7697
+ #else
7667
7698
const int np = (nc & ~(GGML_F32_STEP - 1 ));
7668
7699
7669
7700
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -7694,6 +7725,7 @@ static void ggml_compute_forward_ssm_scan_f32(
7694
7725
7695
7726
// reduce sum0..sum3 to sum0
7696
7727
GGML_F32_VEC_REDUCE (sumf, sum);
7728
+ #endif
7697
7729
#else
7698
7730
const int np = 0 ;
7699
7731
#endif
@@ -7722,7 +7754,7 @@ static void ggml_compute_forward_ssm_scan_f32(
7722
7754
for (int i1 = 0 ; i1 < nr; ++i1) {
7723
7755
const int ii = i1 + h*nr;
7724
7756
const float x_dt = x[ii] * dt_soft_plus;
7725
- #ifdef __ARM_FEATURE_SVE
7757
+ #if defined( __ARM_FEATURE_SVE)
7726
7758
svfloat32_t vx_dt = GGML_F32_VEC_SET1 (x_dt);
7727
7759
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1 (dt_soft_plus);
7728
7760
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
0 commit comments