Skip to content

Commit 93a8e64

Browse files
committed
mamba : reduce memory usage of ggml_ssm_scan
Reduces the CPU compute buffer size from 290.37 MiB to 140.68 MiB for Mamba 3B at a batch size of 512. The result tensor of ggml_ssm_scan was previously a large part of the CPU compute buffer; to shrink it, it no longer contains the intermediate ssm states. Both y and the last ssm state are packed into the single result tensor, because the way the graph is built allows an operator to return only one tensor.
1 parent dff1b70 commit 93a8e64

File tree

3 files changed

+70
-50
lines changed

3 files changed

+70
-50
lines changed

ggml.c

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6054,14 +6054,15 @@ struct ggml_tensor * ggml_ssm_scan(
60546054
struct ggml_tensor * x,
60556055
struct ggml_tensor * dt,
60566056
struct ggml_tensor * A,
6057-
struct ggml_tensor * B) {
6057+
struct ggml_tensor * B,
6058+
struct ggml_tensor * C) {
60586059
GGML_ASSERT(ggml_is_contiguous(s));
60596060
GGML_ASSERT(ggml_is_contiguous(x));
60606061
GGML_ASSERT(ggml_is_contiguous(dt));
60616062
GGML_ASSERT(ggml_is_contiguous(A));
60626063
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
6064+
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
60636065
GGML_ASSERT(ggml_are_same_shape(x, dt));
6064-
GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
60656066

60666067
{
60676068
const int64_t d_state = s->ne[0];
@@ -6073,6 +6074,8 @@ struct ggml_tensor * ggml_ssm_scan(
60736074
GGML_ASSERT(A->ne[1] == d_inner);
60746075
GGML_ASSERT(B->ne[0] == d_state);
60756076
GGML_ASSERT(B->ne[1] == n_tokens);
6077+
GGML_ASSERT(C->ne[0] == d_state);
6078+
GGML_ASSERT(C->ne[1] == n_tokens);
60766079
}
60776080

60786081
bool is_node = false;
@@ -6082,7 +6085,8 @@ struct ggml_tensor * ggml_ssm_scan(
60826085
is_node = true;
60836086
}
60846087

6085-
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, s->ne[0], s->ne[1], x->ne[1]);
6088+
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
6089+
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
60866090

60876091
result->op = GGML_OP_SSM_SCAN;
60886092
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6091,6 +6095,7 @@ struct ggml_tensor * ggml_ssm_scan(
60916095
result->src[2] = dt;
60926096
result->src[3] = A;
60936097
result->src[4] = B;
6098+
result->src[5] = C;
60946099

60956100
return result;
60966101
}
@@ -14609,6 +14614,7 @@ static void ggml_compute_forward_ssm_scan_f32(
1460914614
const struct ggml_tensor * src2, // dt
1461014615
const struct ggml_tensor * src3, // A
1461114616
const struct ggml_tensor * src4, // B
14617+
const struct ggml_tensor * src5, // C
1461214618
struct ggml_tensor * dst) {
1461314619
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
1461414620
return;
@@ -14617,67 +14623,84 @@ static void ggml_compute_forward_ssm_scan_f32(
1461714623
const int ith = params->ith;
1461814624
const int nth = params->nth;
1461914625

14620-
const int64_t nc = src0->ne[0];
14626+
const int64_t nc = src0->ne[0]; // d_state
14627+
const int64_t nr = src0->ne[1]; // d_inner
1462114628
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
14622-
const int64_t nr0 = ggml_nrows(src0);
1462314629

14624-
GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(dst));
14630+
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
1462514631
GGML_ASSERT(src0->nb[0] == sizeof(float));
1462614632
GGML_ASSERT(src1->nb[0] == sizeof(float));
1462714633
GGML_ASSERT(src2->nb[0] == sizeof(float));
1462814634
GGML_ASSERT(src3->nb[0] == sizeof(float));
1462914635
GGML_ASSERT(src4->nb[0] == sizeof(float));
14630-
// allow merging multiple rows in the same vec operation
14636+
GGML_ASSERT(src5->nb[0] == sizeof(float));
14637+
// required for the dot product between s and C
1463114638
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
14632-
GGML_ASSERT(src3->nb[1] == src3->ne[0]*sizeof(float));
14639+
// required to get correct offset for state destination
14640+
GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
1463314641

1463414642
// rows per thread
14635-
const int dr = (nr0 + nth - 1)/nth;
14643+
const int dr = (nr + nth - 1)/nth;
1463614644

1463714645
// row range for this thread
1463814646
const int ir0 = dr*ith;
14639-
const int ir1 = MIN(ir0 + dr, nr0);
14647+
const int ir1 = MIN(ir0 + dr, nr);
1464014648
const int ir = ir1 - ir0;
1464114649

14642-
// first batch
14650+
// first token in the batch
1464314651
{
14644-
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
14645-
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
14646-
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
14647-
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
14648-
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
14649-
float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
14652+
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
14653+
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
14654+
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner, n_kv}
14655+
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
14656+
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
14657+
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
14658+
float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
14659+
float * C = (float *) ((char *) src5->data); // {d_state, n_tokens}
1465014660
// d_inner
1465114661
for (int i1 = 0; i1 < ir; ++i1) {
1465214662
float dt_soft_plus = log1pf(expf(dt[i1]));
1465314663
float x_dt = x[i1] * dt_soft_plus;
14664+
float sumf = 0.0f;
1465414665
// d_state
1465514666
for (int i0 = 0; i0 < nc; ++i0) {
1465614667
int i = i0 + i1*nc;
14657-
// ssm_state * dA + dB * x
14658-
pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14668+
// state = prev_state * dA + dB * x
14669+
float state = s0[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14670+
// y = rowwise_dotprod(state, C)
14671+
sumf += state*C[i0];
14672+
// FIXME: handle simultaneous sequences
14673+
s[i] = state;
1465914674
}
14675+
y[i1] = sumf;
1466014676
}
1466114677
}
1466214678

14663-
// compute state for rest of tokens, previous state comes from dest
14679+
// rest of the batch, state comes from previous one which was stored in destination
1466414680
for (int i2 = 1; i2 < n_t; ++i2) {
14665-
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens}
14666-
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
14667-
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tokens}
14668-
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tokens}
14669-
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
14670-
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
14681+
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
14682+
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
14683+
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
14684+
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
14685+
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
14686+
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
14687+
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
1467114688
// d_inner
1467214689
for (int i1 = 0; i1 < ir; ++i1) {
1467314690
float dt_soft_plus = log1pf(expf(dt[i1]));
1467414691
float x_dt = x[i1] * dt_soft_plus;
14692+
float sumf = 0.0f;
1467514693
// d_state
1467614694
for (int i0 = 0; i0 < nc; ++i0) {
1467714695
int i = i0 + i1*nc;
14678-
// ssm_state * dA + dB * x
14679-
pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14696+
// state = prev_state * dA + dB * x
14697+
float state = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14698+
// y = rowwise_dotprod(state, C)
14699+
sumf += state*C[i0];
14700+
// FIXME: handle simultaneous sequences
14701+
s[i] = state;
1468014702
}
14703+
y[i1] = sumf;
1468114704
}
1468214705
}
1468314706
}
@@ -14689,11 +14712,12 @@ static void ggml_compute_forward_ssm_scan(
1468914712
const struct ggml_tensor * src2,
1469014713
const struct ggml_tensor * src3,
1469114714
const struct ggml_tensor * src4,
14715+
const struct ggml_tensor * src5,
1469214716
struct ggml_tensor * dst) {
1469314717
switch (src0->type) {
1469414718
case GGML_TYPE_F32:
1469514719
{
14696-
ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, dst);
14720+
ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, dst);
1469714721
} break;
1469814722
default:
1469914723
{
@@ -15752,7 +15776,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
1575215776
} break;
1575315777
case GGML_OP_SSM_SCAN:
1575415778
{
15755-
ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
15779+
ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor);
1575615780
} break;
1575715781
case GGML_OP_WIN_PART:
1575815782
{

ggml.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1696,7 +1696,8 @@ extern "C" {
16961696
struct ggml_tensor * x,
16971697
struct ggml_tensor * dt,
16981698
struct ggml_tensor * A,
1699-
struct ggml_tensor * B);
1699+
struct ggml_tensor * B,
1700+
struct ggml_tensor * C);
17001701

17011702
// partition into non-overlapping windows with padding if needed
17021703
// example:

llama.cpp

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7803,9 +7803,9 @@ struct llm_build_context {
78037803
ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
78047804
ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner), kv_self.size);
78057805

7806+
// clear states of sequences which are starting at the beginning of this batch
78067807
{
78077808
ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
7808-
// clear states of sequences which are starting at the beginning of this batch
78097809
conv_states = ggml_mul(ctx0,
78107810
ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
78117811
state_mask);
@@ -7814,11 +7814,8 @@ struct llm_build_context {
78147814
state_mask);
78157815
}
78167816

7817-
// TODO: support more than one sequence per batch (these could then use ggml_reshape_3d)
7818-
ggml_tensor * conv_state = ggml_view_2d(ctx0, conv_states, d_conv - 1, d_inner,
7819-
(d_conv - 1)*ggml_element_size(conv_states), 0);
7820-
ggml_tensor * ssm_state = ggml_view_2d(ctx0, ssm_states, d_state, d_inner,
7821-
(d_state)*ggml_element_size(ssm_states), 0);
7817+
struct ggml_tensor * conv_state = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
7818+
struct ggml_tensor * ssm_state = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
78227819

78237820
// norm
78247821
cur = llm_build_norm(ctx0, inpL, hparams,
@@ -7873,7 +7870,7 @@ struct llm_build_context {
78737870

78747871
// ssm
78757872
{
7876-
// {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
7873+
// {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
78777874
struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
78787875
// split
78797876
struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
@@ -7884,22 +7881,20 @@ struct llm_build_context {
78847881
dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
78857882
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
78867883

7887-
// Custom operator to implement some of the optimizations
7888-
// described in the Annex D of the Mamba paper.
7889-
// TODO: maybe also optimize step 4 of the Speed section of Annex D (the mul_mat with C)
7890-
// => {d_state, d_inner, n_tokens}
7891-
ssm_state = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B);
7884+
// Custom operator to optimize the parallel associative scan
7885+
// as described in the Annex D of the Mamba paper.
7886+
// => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
7887+
// because only a single tensor can be returned.
7888+
struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B, C);
78927889

7893-
// only store last state
7890+
// store last states (the second part of y_ssm_states)
78947891
ggml_build_forward_expand(gf,
78957892
ggml_cpy(ctx0,
7896-
ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tokens-1)*ssm_state->nb[2]),
7897-
ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state))));
7893+
ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
7894+
ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state))));
7895+
7896+
struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
78987897

7899-
// {d_state, d_inner, n_tokens} * {d_state, n_tokens} => {d_inner, 1, n_tokens}
7900-
struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
7901-
// => {d_inner, n_tokens}
7902-
y = ggml_permute(ctx0, y, 0, 2, 1, 3);
79037898
// {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
79047899
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
79057900
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));

0 commit comments

Comments (0)