@@ -6087,14 +6087,15 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * x,
         struct ggml_tensor  * dt,
         struct ggml_tensor  * A,
-        struct ggml_tensor  * B) {
+        struct ggml_tensor  * B,
+        struct ggml_tensor  * C) {
     GGML_ASSERT(ggml_is_contiguous(s));
     GGML_ASSERT(ggml_is_contiguous(x));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+    GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
     GGML_ASSERT(ggml_are_same_shape(x, dt));
-    GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
 
     {
         const int64_t d_state = s->ne[0];
@@ -6106,6 +6107,8 @@ struct ggml_tensor * ggml_ssm_scan(
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
         GGML_ASSERT(B->ne[1] == n_tokens);
+        GGML_ASSERT(C->ne[0] == d_state);
+        GGML_ASSERT(C->ne[1] == n_tokens);
     }
 
     bool is_node = false;
@@ -6115,7 +6118,8 @@ struct ggml_tensor * ggml_ssm_scan(
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, s->ne[0], s->ne[1], x->ne[1]);
+    // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
 
     result->op   = GGML_OP_SSM_SCAN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6124,6 +6128,7 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[2] = dt;
     result->src[3] = A;
     result->src[4] = B;
+    result->src[5] = C;
 
     return result;
 }
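
Note: since a ggml op produces a single output tensor, y and the updated ssm_states are packed back-to-back into one flat F32 buffer, which keeps the state update inside the compute graph. A sketch of how a caller might view the two parts back out (hypothetical, not from this patch; `ctx`, `result`, `d_state`, `d_inner`, `n_tokens` and `n_kv` are assumed to be in scope):

```c
// Hypothetical split of the 2-in-1 ggml_ssm_scan result (a sketch, not this patch).
const size_t es = ggml_element_size(result); // sizeof(float) for an F32 tensor

// y comes first: {d_inner, n_tokens}
struct ggml_tensor * y = ggml_view_2d(ctx, result, d_inner, n_tokens,
        d_inner*es, 0);

// the updated states follow: {d_state, d_inner, n_kv},
// at an element offset of d_inner*n_tokens (the size of y)
struct ggml_tensor * ssm_states = ggml_view_3d(ctx, result, d_state, d_inner, n_kv,
        d_state*es, d_state*d_inner*es, d_inner*n_tokens*es);
```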
@@ -14650,6 +14655,7 @@ static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_tensor * src2, // dt
         const struct ggml_tensor * src3, // A
         const struct ggml_tensor * src4, // B
+        const struct ggml_tensor * src5, // C
         struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14658,67 +14664,84 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc  = src0->ne[0];
+    const int64_t nc  = src0->ne[0]; // d_state
+    const int64_t nr  = src0->ne[1]; // d_inner
     const int64_t n_t = src1->ne[1]; // number of tokens in the batch
-    const int64_t nr0 = ggml_nrows(src0);
 
-    GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
-    // allow merging multiple rows in the same vec operation
+    GGML_ASSERT(src5->nb[0] == sizeof(float));
+    // required for the dot product between s and C
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    GGML_ASSERT(src3->nb[1] == src3->ne[0]*sizeof(float));
+    // required to get correct offset for state destination
+    GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
 
     // rows per thread
-    const int dr = (nr0 + nth - 1)/nth;
+    const int dr = (nr + nth - 1)/nth;
 
     // row range for this thread
     const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr0);
+    const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    // first batch
+    // first token in the batch
     {
-        float * pdst = (float *) ((char *) dst->data  + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
-        float * s    = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
-        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
-        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
-        float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B    = (float *) ((char *) src4->data); // {d_state, n_tokens}
+        float * y  = (float *) ((char *) dst->data  + ir0*(src1->nb[0])); // {d_inner, n_tokens}
+        float * s  = (float *) ((char *) dst->data  + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
+        float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner, n_kv}
+        float * x  = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
+        float * A  = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+        float * B  = (float *) ((char *) src4->data); // {d_state, n_tokens}
+        float * C  = (float *) ((char *) src5->data); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));
             float x_dt = x[i1] * dt_soft_plus;
+            float sumf = 0.0f;
             // d_state
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
-                // ssm_state * dA + dB * x
-                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                // state = prev_state * dA + dB * x
+                float state = s0[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                // y = rowwise_dotprod(state, C)
+                sumf += state*C[i0];
+                // FIXME: handle simultaneous sequences
+                s[i] = state;
             }
+            y[i1] = sumf;
         }
     }
 
-    // compute state for rest of tokens, previous state comes from dest
+    // rest of the batch, state comes from previous one which was stored in destination
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * pdst = (float *) ((char *) dst->data  + ir0*( dst->nb[1]) + (i2  )*( dst->nb[2])); // {d_state, d_inner, n_tokens}
-        float * s    = (float *) ((char *) dst->data  + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
-        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
-        float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B    = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
+        float * y  = (float *) ((char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+        float * s  = (float *) ((char *) dst->data  + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
+        float * x  = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
+        float * A  = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+        float * B  = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
+        float * C  = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));
             float x_dt = x[i1] * dt_soft_plus;
+            float sumf = 0.0f;
             // d_state
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
-                // ssm_state * dA + dB * x
-                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                // state = prev_state * dA + dB * x
+                float state = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                // y = rowwise_dotprod(state, C)
+                sumf += state*C[i0];
+                // FIXME: handle simultaneous sequences
+                s[i] = state;
             }
+            y[i1] = sumf;
         }
     }
 }
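
For reference, per inner channel this kernel is a direct implementation of the discretized selective-scan recurrence. With $\Delta_t = \mathrm{softplus}(dt_t) = \log(1 + e^{dt_t})$ (the `log1pf(expf(...))` above), each token updates the hidden state and produces one output element:

$$
h_t = e^{\Delta_t A} \odot h_{t-1} + \Delta_t B_t \, x_t, \qquad
y_t = C_t^{\top} h_t
$$

The `FIXME` markers flag that both loops write every token's state to the same destination slot (the offset of `s` does not depend on `i2`), which is why batches containing multiple simultaneous sequences are not handled yet.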
@@ -14730,11 +14753,12 @@ static void ggml_compute_forward_ssm_scan(
         const struct ggml_tensor * src2,
         const struct ggml_tensor * src3,
         const struct ggml_tensor * src4,
+        const struct ggml_tensor * src5,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, dst);
+                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, dst);
             } break;
         default:
             {
@@ -15796,7 +15820,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             } break;
         case GGML_OP_SSM_SCAN:
             {
-                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
+                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor);
             } break;
         case GGML_OP_WIN_PART:
             {