@@ -6054,14 +6054,15 @@ struct ggml_tensor * ggml_ssm_scan(
6054
6054
struct ggml_tensor * x,
6055
6055
struct ggml_tensor * dt,
6056
6056
struct ggml_tensor * A,
6057
- struct ggml_tensor * B) {
6057
+ struct ggml_tensor * B,
6058
+ struct ggml_tensor * C) {
6058
6059
GGML_ASSERT(ggml_is_contiguous(s));
6059
6060
GGML_ASSERT(ggml_is_contiguous(x));
6060
6061
GGML_ASSERT(ggml_is_contiguous(dt));
6061
6062
GGML_ASSERT(ggml_is_contiguous(A));
6062
6063
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
6064
+ GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
6063
6065
GGML_ASSERT(ggml_are_same_shape(x, dt));
6064
- GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
6065
6066
6066
6067
{
6067
6068
const int64_t d_state = s->ne[0];
@@ -6073,6 +6074,8 @@ struct ggml_tensor * ggml_ssm_scan(
6073
6074
GGML_ASSERT(A->ne[1] == d_inner);
6074
6075
GGML_ASSERT(B->ne[0] == d_state);
6075
6076
GGML_ASSERT(B->ne[1] == n_tokens);
6077
+ GGML_ASSERT(C->ne[0] == d_state);
6078
+ GGML_ASSERT(C->ne[1] == n_tokens);
6076
6079
}
6077
6080
6078
6081
bool is_node = false;
@@ -6082,7 +6085,8 @@ struct ggml_tensor * ggml_ssm_scan(
6082
6085
is_node = true;
6083
6086
}
6084
6087
6085
- struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, s->ne[0], s->ne[1], x->ne[1]);
6088
+ // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
6089
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
6086
6090
6087
6091
result->op = GGML_OP_SSM_SCAN;
6088
6092
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6091,6 +6095,7 @@ struct ggml_tensor * ggml_ssm_scan(
6091
6095
result->src[2] = dt;
6092
6096
result->src[3] = A;
6093
6097
result->src[4] = B;
6098
+ result->src[5] = C;
6094
6099
6095
6100
return result;
6096
6101
}
@@ -14609,6 +14614,7 @@ static void ggml_compute_forward_ssm_scan_f32(
14609
14614
const struct ggml_tensor * src2, // dt
14610
14615
const struct ggml_tensor * src3, // A
14611
14616
const struct ggml_tensor * src4, // B
14617
+ const struct ggml_tensor * src5, // C
14612
14618
struct ggml_tensor * dst) {
14613
14619
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
14614
14620
return;
@@ -14617,67 +14623,84 @@ static void ggml_compute_forward_ssm_scan_f32(
14617
14623
const int ith = params->ith;
14618
14624
const int nth = params->nth;
14619
14625
14620
- const int64_t nc = src0->ne[0];
14626
+ const int64_t nc = src0->ne[0]; // d_state
14627
+ const int64_t nr = src0->ne[1]; // d_inner
14621
14628
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
14622
- const int64_t nr0 = ggml_nrows(src0);
14623
14629
14624
- GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(dst));
14630
+ GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
14625
14631
GGML_ASSERT(src0->nb[0] == sizeof(float));
14626
14632
GGML_ASSERT(src1->nb[0] == sizeof(float));
14627
14633
GGML_ASSERT(src2->nb[0] == sizeof(float));
14628
14634
GGML_ASSERT(src3->nb[0] == sizeof(float));
14629
14635
GGML_ASSERT(src4->nb[0] == sizeof(float));
14630
- // allow merging multiple rows in the same vec operation
14636
+ GGML_ASSERT(src5->nb[0] == sizeof(float));
14637
+ // required for the dot product between s and C
14631
14638
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
14632
- GGML_ASSERT(src3->nb[1] == src3->ne[0]*sizeof(float));
14639
+ // required to get correct offset for state destination
14640
+ GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
14633
14641
14634
14642
// rows per thread
14635
- const int dr = (nr0 + nth - 1)/nth;
14643
+ const int dr = (nr + nth - 1)/nth;
14636
14644
14637
14645
// row range for this thread
14638
14646
const int ir0 = dr*ith;
14639
- const int ir1 = MIN(ir0 + dr, nr0 );
14647
+ const int ir1 = MIN(ir0 + dr, nr );
14640
14648
const int ir = ir1 - ir0;
14641
14649
14642
- // first batch
14650
+ // first token in the batch
14643
14651
{
14644
- float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
14645
- float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
14646
- float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
14647
- float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
14648
- float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
14649
- float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
14652
+ float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
14653
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
14654
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner, n_kv}
14655
+ float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
14656
+ float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
14657
+ float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
14658
+ float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
14659
+ float * C = (float *) ((char *) src5->data); // {d_state, n_tokens}
14650
14660
// d_inner
14651
14661
for (int i1 = 0; i1 < ir; ++i1) {
14652
14662
float dt_soft_plus = log1pf(expf(dt[i1]));
14653
14663
float x_dt = x[i1] * dt_soft_plus;
14664
+ float sumf = 0.0f;
14654
14665
// d_state
14655
14666
for (int i0 = 0; i0 < nc; ++i0) {
14656
14667
int i = i0 + i1*nc;
14657
- // ssm_state * dA + dB * x
14658
- pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14668
+ // state = prev_state * dA + dB * x
14669
+ float state = s0[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14670
+ // y = rowwise_dotprod(state, C)
14671
+ sumf += state*C[i0];
14672
+ // FIXME: handle simultaneous sequences
14673
+ s[i] = state;
14659
14674
}
14675
+ y[i1] = sumf;
14660
14676
}
14661
14677
}
14662
14678
14663
- // compute state for rest of tokens, previous state comes from dest
14679
+ // rest of the batch, state comes from previous one which was stored in destination
14664
14680
for (int i2 = 1; i2 < n_t; ++i2) {
14665
- float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens}
14666
- float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
14667
- float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tokens}
14668
- float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tokens}
14669
- float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
14670
- float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
14681
+ float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
14682
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
14683
+ float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
14684
+ float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
14685
+ float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
14686
+ float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
14687
+ float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
14671
14688
// d_inner
14672
14689
for (int i1 = 0; i1 < ir; ++i1) {
14673
14690
float dt_soft_plus = log1pf(expf(dt[i1]));
14674
14691
float x_dt = x[i1] * dt_soft_plus;
14692
+ float sumf = 0.0f;
14675
14693
// d_state
14676
14694
for (int i0 = 0; i0 < nc; ++i0) {
14677
14695
int i = i0 + i1*nc;
14678
- // ssm_state * dA + dB * x
14679
- pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14696
+ // state = prev_state * dA + dB * x
14697
+ float state = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14698
+ // y = rowwise_dotprod(state, C)
14699
+ sumf += state*C[i0];
14700
+ // FIXME: handle simultaneous sequences
14701
+ s[i] = state;
14680
14702
}
14703
+ y[i1] = sumf;
14681
14704
}
14682
14705
}
14683
14706
}
@@ -14689,11 +14712,12 @@ static void ggml_compute_forward_ssm_scan(
14689
14712
const struct ggml_tensor * src2,
14690
14713
const struct ggml_tensor * src3,
14691
14714
const struct ggml_tensor * src4,
14715
+ const struct ggml_tensor * src5,
14692
14716
struct ggml_tensor * dst) {
14693
14717
switch (src0->type) {
14694
14718
case GGML_TYPE_F32:
14695
14719
{
14696
- ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, dst);
14720
+ ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, dst);
14697
14721
} break;
14698
14722
default:
14699
14723
{
@@ -15752,7 +15776,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
15752
15776
} break;
15753
15777
case GGML_OP_SSM_SCAN:
15754
15778
{
15755
- ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
15779
+ ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor );
15756
15780
} break;
15757
15781
case GGML_OP_WIN_PART:
15758
15782
{
0 commit comments