@@ -6093,8 +6093,8 @@ struct ggml_tensor * ggml_ssm_scan(
6093
6093
GGML_ASSERT(ggml_is_contiguous(dt));
6094
6094
GGML_ASSERT(ggml_is_contiguous(A));
6095
6095
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
6096
- ggml_are_same_shape(x, dt);
6097
- GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1 ); // the ssm_state should be 2D
6096
+ GGML_ASSERT( ggml_are_same_shape(x, dt) );
6097
+ GGML_ASSERT(ggml_is_matrix(s) ); // the ssm_state should be 2D
6098
6098
6099
6099
{
6100
6100
const int64_t d_state = s->ne[0];
@@ -6111,6 +6111,7 @@ struct ggml_tensor * ggml_ssm_scan(
6111
6111
bool is_node = false;
6112
6112
6113
6113
if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
6114
+ GGML_ASSERT(false); // TODO: implement
6114
6115
is_node = true;
6115
6116
}
6116
6117
@@ -14681,7 +14682,7 @@ static void ggml_compute_forward_ssm_scan_f32(
14681
14682
14682
14683
// first batch
14683
14684
{
14684
- float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
14685
+ float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
14685
14686
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
14686
14687
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
14687
14688
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
@@ -14695,14 +14696,14 @@ static void ggml_compute_forward_ssm_scan_f32(
14695
14696
for (int i0 = 0; i0 < nc; ++i0) {
14696
14697
int i = i0 + i1*nc;
14697
14698
// ssm_state * dA + dB * x
14698
- dest [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14699
+ pdst [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14699
14700
}
14700
14701
}
14701
14702
}
14702
14703
14703
14704
// compute state for rest of tokens, previous state comes from dst (the slot written for token i2-1)
14704
14705
for (int i2 = 1; i2 < n_t; ++i2) {
14705
- float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
14706
+ float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
14706
14707
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
14707
14708
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
14708
14709
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
@@ -14716,7 +14717,7 @@ static void ggml_compute_forward_ssm_scan_f32(
14716
14717
for (int i0 = 0; i0 < nc; ++i0) {
14717
14718
int i = i0 + i1*nc;
14718
14719
// ssm_state * dA + dB * x
14719
- dest [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14720
+ pdst [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14720
14721
}
14721
14722
}
14722
14723
}
0 commit comments