Commit 5d8d127

mamba : apply suggestions from code review

* mamba : remove unnecessary branch for row-wise ssm_state and C multiplication.
  It was previously done to avoid permuting when only one token is processed at a time (like when generating text), but permuting is cheap, and dynamically changing the compute graph is not future-proof.
* ggml : in ggml_ssm_scan, use more appropriate asserts
* ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32

1 parent 19cfe40

File tree

2 files changed, +11 -17 lines changed


ggml.c

Lines changed: 7 additions & 6 deletions
@@ -5944,8 +5944,8 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
-    ggml_are_same_shape(x, dt);
-    GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
+    GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
 
     {
         const int64_t d_state = s->ne[0];
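Note: the first removed line called ggml_are_same_shape without using its result, so the shape check was silently a no-op; wrapping it in GGML_ASSERT makes it actually enforce that x and dt match. The new ggml_is_matrix assert expresses the same 2D condition as the removed inline check. A minimal sketch of that helper, assuming it matches its definition in ggml.c at this revision:

bool ggml_is_matrix(const struct ggml_tensor * tensor) {
    // a "matrix" here means the two trailing dims are unused
    return tensor->ne[2] == 1 && tensor->ne[3] == 1;
}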
@@ -5962,6 +5962,7 @@ struct ggml_tensor * ggml_ssm_scan(
     bool is_node = false;
 
     if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
+        GGML_ASSERT(false); // TODO: implement
         is_node = true;
     }
 
@@ -14236,7 +14237,7 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // first batch
     {
-        float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
         float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
         float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
         float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
@@ -14250,14 +14251,14 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
                 // ssm_state * dA + dB * x
-                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
             }
         }
     }
 
     // compute state for rest of tokens, previous state comes from dest
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
         float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
         float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
         float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
@@ -14271,7 +14272,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
                 // ssm_state * dA + dB * x
-                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
             }
         }
     }
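Both renamed hunks compute the same per-element update. A minimal standalone sketch of one scan step, assuming (as the surrounding kernel suggests, but these lines do not show) that dt_soft_plus = softplus(dt) and x_dt = x * dt_soft_plus:

#include <math.h>

// one selective-scan step for a single (state, channel) element:
// pdst = s * dA + dB * x, with dA = exp(softplus(dt) * A) and dB = softplus(dt) * B
static float ssm_scan_step(float s, float A, float B, float x, float dt) {
    const float dt_soft_plus = log1pf(expf(dt)); // softplus(dt), naive form
    const float x_dt = x * dt_soft_plus;
    return s * expf(dt_soft_plus * A) + (B * x_dt);
}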

llama.cpp

Lines changed: 4 additions & 11 deletions
@@ -7560,17 +7560,10 @@ struct llm_build_context {
                     ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
                     ggml_view_tensor(ctx0, kv_self.v_l[il])));
 
-            struct ggml_tensor * y;
-            if (n_tok == 1) {
-                // row-wise dot product ("dn,n->d")
-                // {d_state, d_inner} * {d_state, 1} => {d_inner, 1}
-                y = ggml_mul_mat(ctx0, ssm_state, C);
-            } else {
-                // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
-                y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
-                // => {d_inner, n_tok}
-                y = ggml_permute(ctx0, y, 0, 2, 1, 3);
-            }
+            // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
+            struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
+            // => {d_inner, n_tok}
+            y = ggml_permute(ctx0, y, 0, 2, 1, 3);
             // {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
             y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
             y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
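The single-token branch was removable because ggml_permute only creates a view: when n_tok == 1, permuting C from {d_state, 1} to {d_state, 1, 1} costs nothing, and the batched ggml_mul_mat reduces to the same row-wise dot product. A hypothetical standalone sketch of the shape flow (dimensions are illustrative, not from the model):

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /* .mem_size   = */ 16*1024*1024,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true, // only building ops to inspect shapes
    };
    struct ggml_context * ctx = ggml_init(params);

    const int d_state = 16, d_inner = 32, n_tok = 1; // single-token case
    struct ggml_tensor * S = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_tok);
    struct ggml_tensor * C = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, n_tok);

    // {d_state, n_tok} -> {d_state, 1, n_tok}: a free view, even for n_tok == 1
    struct ggml_tensor * Cp = ggml_permute(ctx, C, 0, 2, 1, 3);
    // {d_state, d_inner, n_tok} * {d_state, 1, n_tok} => {d_inner, 1, n_tok}
    struct ggml_tensor * y  = ggml_mul_mat(ctx, S, Cp);
    // => {d_inner, n_tok, 1}, matching the removed row-wise branch when n_tok == 1
    y = ggml_permute(ctx, y, 0, 2, 1, 3);

    printf("y: {%lld, %lld, %lld}\n",
           (long long) y->ne[0], (long long) y->ne[1], (long long) y->ne[2]);

    ggml_free(ctx);
    return 0;
}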
