@@ -5944,8 +5944,8 @@ struct ggml_tensor * ggml_ssm_scan(
5944
5944
GGML_ASSERT(ggml_is_contiguous(dt));
5945
5945
GGML_ASSERT(ggml_is_contiguous(A));
5946
5946
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
5947
- ggml_are_same_shape(x, dt);
5948
- GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1 ); // the ssm_state should be 2D
5947
+ GGML_ASSERT( ggml_are_same_shape(x, dt) );
5948
+ GGML_ASSERT(ggml_is_matrix(s) ); // the ssm_state should be 2D
5949
5949
5950
5950
{
5951
5951
const int64_t d_state = s->ne[0];
@@ -5962,6 +5962,7 @@ struct ggml_tensor * ggml_ssm_scan(
5962
5962
bool is_node = false;
5963
5963
5964
5964
if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
5965
+ GGML_ASSERT(false); // TODO: implement
5965
5966
is_node = true;
5966
5967
}
5967
5968
@@ -14236,7 +14237,7 @@ static void ggml_compute_forward_ssm_scan_f32(
14236
14237
14237
14238
// first batch
14238
14239
{
14239
- float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
14240
+ float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
14240
14241
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
14241
14242
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
14242
14243
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
@@ -14250,14 +14251,14 @@ static void ggml_compute_forward_ssm_scan_f32(
14250
14251
for (int i0 = 0; i0 < nc; ++i0) {
14251
14252
int i = i0 + i1*nc;
14252
14253
// ssm_state * dA + dB * x
14253
- dest [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14254
+ pdst [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14254
14255
}
14255
14256
}
14256
14257
}
14257
14258
14258
14259
// compute state for the remaining tokens; the previous state is read from the preceding token's slice of dst
14259
14260
for (int i2 = 1; i2 < n_t; ++i2) {
14260
- float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
14261
+ float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
14261
14262
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
14262
14263
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
14263
14264
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
@@ -14271,7 +14272,7 @@ static void ggml_compute_forward_ssm_scan_f32(
14271
14272
for (int i0 = 0; i0 < nc; ++i0) {
14272
14273
int i = i0 + i1*nc;
14273
14274
// ssm_state * dA + dB * x
14274
- dest [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14275
+ pdst [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14275
14276
}
14276
14277
}
14277
14278
}
0 commit comments