Commit 5db47b8

mamba : apply suggestions from code review

* mamba : remove unnecessary branch for row-wise ssm_state and C multiplication. It was previously done to avoid permuting when only one token is processed at a time (like when generating text), but permuting is cheap, and dynamically changing the compute graph is not future-proof.
* ggml : in ggml_ssm_scan, use more appropriate asserts
* ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32
1 parent 5988918 commit 5db47b8

2 files changed (+11, -17 lines)

ggml.c

Lines changed: 7 additions & 6 deletions
@@ -6093,8 +6093,8 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
-    ggml_are_same_shape(x, dt);
-    GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
+    GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
 
     {
         const int64_t d_state = s->ne[0];
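Two fixes here: the bare ggml_are_same_shape(x, dt) call discarded its result, so it checked nothing until wrapped in GGML_ASSERT, and the hand-written dimension check is replaced by the existing ggml_is_matrix helper, which states the intent directly. A rough sketch of the asserted condition (illustrative, not the actual ggml source):

#include "ggml.h"

// Rough sketch: a ggml tensor is a "matrix" when its two trailing
// dimensions are 1, i.e. the same condition the old assert spelled out.
static bool is_matrix_sketch(const struct ggml_tensor * t) {
    return t->ne[2] == 1 && t->ne[3] == 1;
}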
@@ -6111,6 +6111,7 @@ struct ggml_tensor * ggml_ssm_scan(
     bool is_node = false;
 
     if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
+        GGML_ASSERT(false); // TODO: implement
         is_node = true;
     }

@@ -14681,7 +14682,7 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // first batch
     {
-        float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
         float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
         float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
         float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
@@ -14695,14 +14696,14 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
                 // ssm_state * dA + dB * x
-                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
             }
         }
     }
 
     // compute state for rest of tokens, previous state comes from dest
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
         float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
         float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
         float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
@@ -14716,7 +14717,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
                 // ssm_state * dA + dB * x
-                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
             }
         }
     }
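For context, these loops implement the Mamba selective-scan update ssm_state = ssm_state * exp(dA) + dB * x, with the discretization step taken as softplus(dt). A minimal standalone sketch of one token's update (the names and the large-value softplus shortcut are illustrative, not the exact kernel):

#include <math.h>

// Illustrative single-token state update:
// state[i] = state[i] * exp(softplus(dt[i1]) * A[i]) + B[i0] * (x[i1] * softplus(dt[i1]))
static void ssm_scan_step_sketch(
        int d_inner, int d_state,
        float       * state, // {d_state, d_inner}, updated in place
        const float * x,     // {d_inner}
        const float * dt,    // {d_inner}
        const float * A,     // {d_state, d_inner}
        const float * B) {   // {d_state}
    for (int i1 = 0; i1 < d_inner; ++i1) {
        // softplus, with a shortcut for large inputs where log1p(exp(v)) ~= v
        const float v = dt[i1];
        const float dt_soft_plus = v <= 20.0f ? log1pf(expf(v)) : v;
        const float x_dt = x[i1] * dt_soft_plus;
        for (int i0 = 0; i0 < d_state; ++i0) {
            const int i = i0 + i1*d_state;
            // ssm_state * dA + dB * x
            state[i] = state[i]*expf(dt_soft_plus * A[i]) + (B[i0] * x_dt);
        }
    }
}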

llama.cpp

Lines changed: 4 additions & 11 deletions
@@ -8001,17 +8001,10 @@ struct llm_build_context {
             ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
             ggml_view_tensor(ctx0, kv_self.v_l[il])));
 
-        struct ggml_tensor * y;
-        if (n_tok == 1) {
-            // row-wise dot product ("dn,n->d")
-            // {d_state, d_inner} * {d_state, 1} => {d_inner, 1}
-            y = ggml_mul_mat(ctx0, ssm_state, C);
-        } else {
-            // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
-            y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
-            // => {d_inner, n_tok}
-            y = ggml_permute(ctx0, y, 0, 2, 1, 3);
-        }
+        // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
+        struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
+        // => {d_inner, n_tok}
+        y = ggml_permute(ctx0, y, 0, 2, 1, 3);
         // {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
         y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
         y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
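The removed n_tok == 1 fast path saved one permute, but ggml_permute only rearranges the tensor's view (no data is copied), so the batched path is about as cheap for a single token and keeps the compute graph independent of batch shape. A hypothetical scalar sketch of the row-wise "dn,n->d" product both branches compute per token:

// Hypothetical per-token "dn,n->d" contraction:
// y[d] = sum over n of ssm_state[n + d*d_state] * C[n]
static void dn_n_to_d_sketch(
        int d_inner, int d_state,
        const float * ssm_state, // {d_state, d_inner}
        const float * C,         // {d_state}
        float       * y) {       // {d_inner}
    for (int d = 0; d < d_inner; ++d) {
        float sum = 0.0f;
        for (int n = 0; n < d_state; ++n) {
            sum += ssm_state[n + d*d_state] * C[n];
        }
        y[d] = sum;
    }
}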
