@@ -7493,74 +7493,151 @@ void ggml_compute_forward_ssm_conv(
7493
7493
static void ggml_compute_forward_ssm_scan_f32 (
7494
7494
const ggml_compute_params * params,
7495
7495
ggml_tensor * dst) {
7496
- const ggml_tensor * src0 = dst->src [0 ]; // s
7497
- const ggml_tensor * src1 = dst->src [1 ]; // x
7498
- const ggml_tensor * src2 = dst->src [2 ]; // dt
7499
- const ggml_tensor * src3 = dst->src [3 ]; // A
7500
- const ggml_tensor * src4 = dst->src [4 ]; // B
7501
- const ggml_tensor * src5 = dst->src [5 ]; // C
7496
+ const ggml_tensor * src0 = dst->src [0 ]; // s {d_state, dim, n_head, n_seqs+}
7497
+ const ggml_tensor * src1 = dst->src [1 ]; // x {dim, n_head, n_seq_tokens, n_seqs}
7498
+ const ggml_tensor * src2 = dst->src [2 ]; // dt {n_head, n_seq_tokens, n_seqs}
7499
+ const ggml_tensor * src3 = dst->src [3 ]; // A {d_state, n_head} or {1, n_head}
7500
+ const ggml_tensor * src4 = dst->src [4 ]; // B {d_state, n_group, n_seq_tokens, n_seqs}
7501
+ const ggml_tensor * src5 = dst->src [5 ]; // C {d_state, n_group, n_seq_tokens, n_seqs}
7502
+ const ggml_tensor * src6 = dst->src [6 ]; // ids {n_seqs}
7502
7503
7503
7504
const int ith = params->ith ;
7504
7505
const int nth = params->nth ;
7505
7506
7506
- const int64_t nc = src0->ne [0 ]; // d_state
7507
- const int64_t nr = src0->ne [1 ]; // d_inner
7508
- const int64_t n_t = src1->ne [1 ]; // number of tokens per sequence
7509
- const int64_t n_s = src0->ne [2 ]; // number of sequences in the batch
7507
+ const int64_t nc = src0->ne [0 ]; // d_state
7508
+ const int64_t nr = src0->ne [1 ]; // dim
7509
+ const int64_t nh = src1->ne [1 ]; // n_head
7510
+ const int64_t ng = src4->ne [1 ];
7511
+ const int64_t nt = src1->ne [2 ]; // number of tokens per sequence
7512
+ const int64_t ns = src1->ne [3 ]; // number of sequences in the batch
7510
7513
7511
- GGML_ASSERT (ggml_nelements (src1) + ggml_nelements (src0) == ggml_nelements (dst));
7514
+ // can't use ggml_nbytes because src1 is not necessarily contiguous
7515
+ const int64_t s_off = ggml_nelements (src1) * ggml_element_size (src1);
7516
+
7517
+ GGML_ASSERT (ggml_nelements (src1) + nc*nr*nh*ns == ggml_nelements (dst));
7512
7518
GGML_ASSERT (src0->nb [0 ] == sizeof (float ));
7513
7519
GGML_ASSERT (src1->nb [0 ] == sizeof (float ));
7514
7520
GGML_ASSERT (src2->nb [0 ] == sizeof (float ));
7515
7521
GGML_ASSERT (src3->nb [0 ] == sizeof (float ));
7516
7522
GGML_ASSERT (src4->nb [0 ] == sizeof (float ));
7517
7523
GGML_ASSERT (src5->nb [0 ] == sizeof (float ));
7518
- // required for the dot product between s and C
7519
- GGML_ASSERT (src0->nb [1 ] == src0->ne [0 ]*sizeof (float ));
7520
- // required for per-sequence offsets for states
7521
- GGML_ASSERT (src0->nb [2 ] == src0->ne [0 ]*src0->ne [1 ]*sizeof (float ));
7522
- // required to get correct offset for state destination (i.e. src1->nb[3])
7523
- GGML_ASSERT (src1->nb [3 ] == src1->ne [0 ]*src1->ne [1 ]*src1->ne [2 ]*sizeof (float ));
7524
+ GGML_ASSERT (src6->nb [0 ] == sizeof (int32_t ));
7525
+ // allows optimizing the modulo since n_group should be a power of 2
7526
+ GGML_ASSERT ((ng & -ng) == ng);
7527
+
7528
+ // heads per thread
7529
+ const int dh = (nh + nth - 1 )/nth;
7530
+
7531
+ // head range for this thread
7532
+ const int ih0 = dh*ith;
7533
+ const int ih1 = MIN (ih0 + dh, nh);
7534
+
7535
+ const int32_t * ids = (const int32_t *) src6->data ;
7536
+
7537
+ for (int i3 = 0 ; i3 < ns; ++i3) {
7538
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb [3 ])); // {d_state, dim, nh, ns}
7539
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb [3 ]) + s_off); // {d_state, dim, nh, ns}
7540
+
7541
+ for (int i2 = 0 ; i2 < nt; ++i2) {
7542
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb [2 ]) + i3*(src1->nb [3 ])); // {dim, nh, nt, ns}
7543
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb [1 ]) + i3*(src2->nb [2 ])); // {nh, nt, ns}
7544
+ const float * A = (const float *) ((const char *) src3->data ); // {d_state, nh} or {1, nh}
7545
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb [2 ]) + i3*(src4->nb [3 ])); // {d_state, ng, nt, ns}
7546
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb [2 ]) + i3*(src5->nb [3 ])); // {d_state, ng, nt, ns}
7547
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof (float )) + i3*(nt*nh*nr*sizeof (float ))); // {dim, nh, nt, ns}
7548
+
7549
+ if (src3->ne [0 ] == 1 ) {
7550
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
7551
+
7552
+ // n_head
7553
+ for (int h = ih0; h < ih1; ++h) {
7554
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
7555
+ const float dt_soft_plus = dt[h] <= 20 .0f ? log1pf (expf (dt[h])) : dt[h];
7556
+ const float dA = expf (dt_soft_plus * A[h]);
7557
+
7558
+ // dim
7559
+ for (int i1 = 0 ; i1 < nr; ++i1) {
7560
+ const int ii = i1 + h*nr;
7561
+ const float x_dt = x[ii] * dt_soft_plus;
7562
+ float sumf = 0 .0f ;
7563
+ #if defined(GGML_SIMD)
7564
+ const int np = (nc & ~(GGML_F32_STEP - 1 ));
7524
7565
7525
- // rows per thread
7526
- const int dr = (nr + nth - 1 )/nth;
7566
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
7527
7567
7528
- // row range for this thread
7529
- const int ir0 = dr*ith;
7530
- const int ir1 = MIN (ir0 + dr, nr);
7531
- const int ir = ir1 - ir0;
7568
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1 (dA);
7569
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1 (x_dt);
7532
7570
7533
- for (int i3 = 0 ; i3 < n_s; ++i3) {
7534
- for (int i2 = 0 ; i2 < n_t ; ++i2) {
7535
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ])); // {d_state, d_inner, n_s}
7536
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7537
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb [0 ]) + i2*(src2->nb [1 ]) + i3*(src2->nb [2 ])); // {d_inner, n_t, n_s}
7538
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb [1 ])); // {d_state, d_inner}
7539
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb [1 ]) + i3*(src4->nb [2 ])); // {d_state, n_t, n_s}
7540
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb [1 ]) + i3*(src5->nb [2 ])); // {d_state, n_t, n_s}
7541
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7542
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ]) + src1->nb [3 ]); // {d_state, d_inner, n_s}
7543
-
7544
- // use the output as the source for the next token-wise iterations
7545
- if (i2 > 0 ) { s0 = s; }
7571
+ GGML_F32_VEC ax[GGML_F32_ARR];
7572
+ GGML_F32_VEC ay[GGML_F32_ARR];
7573
+ GGML_F32_VEC az[GGML_F32_ARR];
7546
7574
7547
- // d_inner
7548
- for (int i1 = 0 ; i1 < ir; ++i1) {
7549
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7550
- float dt_soft_plus = dt[i1] <= 20 .0f ? log1pf (expf (dt[i1])) : dt[i1];
7551
- float x_dt = x[i1] * dt_soft_plus;
7552
- float sumf = 0 .0f ;
7553
- // d_state
7554
- for (int i0 = 0 ; i0 < nc; ++i0) {
7555
- int i = i0 + i1*nc;
7556
- // state = prev_state * dA + dB * x
7557
- float state = (s0[i] * expf (dt_soft_plus * A[i])) + (B[i0] * x_dt);
7558
- // y = rowwise_dotprod(state, C)
7559
- sumf += state * C[i0];
7560
- s[i] = state;
7575
+ for (int i = 0 ; i < np; i += GGML_F32_STEP) {
7576
+ for (int j = 0 ; j < GGML_F32_ARR; j++) {
7577
+ ax[j] = GGML_F32_VEC_LOAD (s0 + i + j*GGML_F32_EPR + ii*nc);
7578
+ ay[j] = GGML_F32_VEC_LOAD (B + i + j*GGML_F32_EPR + (h & (ng - 1 ))*nc);
7579
+ az[j] = GGML_F32_VEC_LOAD (C + i + j*GGML_F32_EPR + (h & (ng - 1 ))*nc);
7580
+
7581
+ ax[j] = GGML_F32_VEC_MUL (ax[j], adA);
7582
+ ay[j] = GGML_F32_VEC_MUL (ay[j], axdt);
7583
+
7584
+ ax[j] = GGML_F32_VEC_ADD (ax[j], ay[j]);
7585
+
7586
+ sum[j] = GGML_F32_VEC_FMA (sum[j], ax[j], az[j]);
7587
+
7588
+ GGML_F32_VEC_STORE (s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
7589
+ }
7590
+ }
7591
+
7592
+ // reduce sum0..sum3 to sum0
7593
+ GGML_F32_VEC_REDUCE (sumf, sum);
7594
+ #else
7595
+ const int np = 0 ;
7596
+ #endif
7597
+ // d_state
7598
+ for (int i0 = np; i0 < nc; ++i0) {
7599
+ const int i = i0 + ii*nc;
7600
+ const int ig = i0 + (h & (ng - 1 ))*nc;
7601
+ // state = prev_state * dA + dB * x
7602
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
7603
+ // y = rowwise_dotprod(state, C)
7604
+ sumf += state * C[ig];
7605
+ s[i] = state;
7606
+ }
7607
+ y[ii] = sumf;
7608
+ }
7609
+ }
7610
+ } else {
7611
+ // Mamba-1 has an element-wise decay factor for the states
7612
+
7613
+ // n_head
7614
+ for (int h = ih0; h < ih1; ++h) {
7615
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
7616
+ const float dt_soft_plus = dt[h] <= 20 .0f ? log1pf (expf (dt[h])) : dt[h];
7617
+
7618
+ // dim
7619
+ for (int i1 = 0 ; i1 < nr; ++i1) {
7620
+ const int ii = i1 + h*nr;
7621
+ const float x_dt = x[ii] * dt_soft_plus;
7622
+ float sumf = 0 .0f ;
7623
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
7624
+ // and also because expf is used within the loop.
7625
+ // d_state
7626
+ for (int i0 = 0 ; i0 < nc; ++i0) {
7627
+ const int i = i0 + ii*nc;
7628
+ const int ig = i0 + (h & (ng - 1 ))*nc;
7629
+ // state = prev_state * dA + dB * x
7630
+ const float state = (s0[i] * expf (dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
7631
+ // y = rowwise_dotprod(state, C)
7632
+ sumf += state * C[ig];
7633
+ s[i] = state;
7634
+ }
7635
+ y[ii] = sumf;
7636
+ }
7561
7637
}
7562
- y[i1] = sumf;
7563
7638
}
7639
+ // use the output as the source when it's not the first token-wise iteration
7640
+ s0 = s;
7564
7641
}
7565
7642
}
7566
7643
}
0 commit comments