
Commit 4043fb8

mamba : reduce memory usage of ggml_ssm_scan
This reduces the CPU compute buffer size from 290.37 MiB to 140.68 MiB for Mamba 3B with a batch size of 512. The result tensor of ggml_ssm_scan was previously a big part of the CPU compute buffer, because it stored all the intermediate ssm states. To make it smaller, it no longer contains them; instead, both y and the last ssm state are combined in the result tensor, because with the way the graph is built, it seems an operator can only return a single tensor.
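As a rough cross-check of those numbers, a back-of-the-envelope sketch in C. The dimensions are assumptions, not values stated in this commit (Mamba 3B is commonly reported with d_state = 16 and d_inner = 5120; n_kv = 1 for a single sequence):

#include <stdio.h>

int main(void) {
    // assumed dimensions for Mamba 3B (not stated in this commit)
    const long d_state = 16, d_inner = 5120, n_tokens = 512, n_kv = 1;
    const long f32 = 4; // bytes per float

    // before: the result held every intermediate state, {d_state, d_inner, n_tokens}
    const long before = d_state * d_inner * n_tokens * f32;
    // after: y {d_inner, n_tokens} concatenated with only the last states {d_state, d_inner, n_kv}
    const long after  = (d_inner * n_tokens + d_state * d_inner * n_kv) * f32;

    printf("before: %.2f MiB\n", before / (1024.0 * 1024.0)); // 160.00 MiB
    printf("after:  %.2f MiB\n", after  / (1024.0 * 1024.0)); //  10.31 MiB
    // the ~149.69 MiB difference matches 290.37 - 140.68 MiB
    return 0;
}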
1 parent 1b8b211 commit 4043fb8

3 files changed: 70 additions, 50 deletions

ggml.c (54 additions, 30 deletions)
@@ -6087,14 +6087,15 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * x,
         struct ggml_tensor  * dt,
         struct ggml_tensor  * A,
-        struct ggml_tensor  * B) {
+        struct ggml_tensor  * B,
+        struct ggml_tensor  * C) {
     GGML_ASSERT(ggml_is_contiguous(s));
     GGML_ASSERT(ggml_is_contiguous(x));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+    GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
     GGML_ASSERT(ggml_are_same_shape(x, dt));
-    GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D

     {
         const int64_t d_state = s->ne[0];
@@ -6106,6 +6107,8 @@ struct ggml_tensor * ggml_ssm_scan(
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
         GGML_ASSERT(B->ne[1] == n_tokens);
+        GGML_ASSERT(C->ne[0] == d_state);
+        GGML_ASSERT(C->ne[1] == n_tokens);
     }

     bool is_node = false;
@@ -6115,7 +6118,8 @@ struct ggml_tensor * ggml_ssm_scan(
         is_node = true;
     }

-    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, s->ne[0], s->ne[1], x->ne[1]);
+    // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));

     result->op   = GGML_OP_SSM_SCAN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6124,6 +6128,7 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[2] = dt;
     result->src[3] = A;
     result->src[4] = B;
+    result->src[5] = C;

     return result;
 }
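For orientation, a sketch of how the concatenated result is meant to be read. This is an illustration of the layout with hypothetical variable names; the actual view calls appear in the llama.cpp changes further down:

// The 1-D result holds two logical tensors back to back:
//   [ y: d_inner*n_tokens floats | last ssm states: d_state*d_inner*n_kv floats ]
// A caller can recover them with views, e.g. (illustrative; assumes ctx, result,
// d_inner, n_tokens, d_state and n_kv are in scope):
struct ggml_tensor * y = ggml_view_2d(ctx, result, d_inner, n_tokens,
        d_inner*ggml_element_size(result), 0);
struct ggml_tensor * states = ggml_view_1d(ctx, result, d_state*d_inner*n_kv,
        d_inner*n_tokens*ggml_element_size(result));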
@@ -14650,6 +14655,7 @@ static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_tensor * src2, // dt
         const struct ggml_tensor * src3, // A
         const struct ggml_tensor * src4, // B
+        const struct ggml_tensor * src5, // C
               struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14658,67 +14664,84 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ith = params->ith;
     const int nth = params->nth;

-    const int64_t nc  = src0->ne[0];
+    const int64_t nc  = src0->ne[0]; // d_state
+    const int64_t nr  = src0->ne[1]; // d_inner
     const int64_t n_t = src1->ne[1]; // number of tokens in the batch
-    const int64_t nr0 = ggml_nrows(src0);

-    GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
-    // allow merging multiple rows in the same vec operation
+    GGML_ASSERT(src5->nb[0] == sizeof(float));
+    // required for the dot product between s and C
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    GGML_ASSERT(src3->nb[1] == src3->ne[0]*sizeof(float));
+    // required to get correct offset for state destination
+    GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));

     // rows per thread
-    const int dr = (nr0 + nth - 1)/nth;
+    const int dr = (nr + nth - 1)/nth;

     // row range for this thread
     const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr0);
+    const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;

-    // first batch
+    // first token in the batch
     {
-        float * pdst = (float *) ((char *)  dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
-        float * s    = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
-        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
-        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
-        float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B    = (float *) ((char *) src4->data); // {d_state, n_tokens}
+        float * y  = (float *) ((char *)  dst->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
+        float * s  = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
+        float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner, n_kv}
+        float * x  = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
+        float * A  = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+        float * B  = (float *) ((char *) src4->data); // {d_state, n_tokens}
+        float * C  = (float *) ((char *) src5->data); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));
             float x_dt = x[i1] * dt_soft_plus;
+            float sumf = 0.0f;
             // d_state
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
-                // ssm_state * dA + dB * x
-                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                // state = prev_state * dA + dB * x
+                float state = s0[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                // y = rowwise_dotprod(state, C)
+                sumf += state*C[i0];
+                // FIXME: handle simultaneous sequences
+                s[i] = state;
             }
+            y[i1] = sumf;
         }
     }

-    // compute state for rest of tokens, previous state comes from dest
+    // rest of the batch, state comes from previous one which was stored in destination
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * pdst = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) +    i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens}
-        float * s    = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
-        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0]) +    i2 *(src1->nb[1])); // {d_inner, n_tokens}
-        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0]) +    i2 *(src2->nb[1])); // {d_inner, n_tokens}
-        float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B    = (float *) ((char *) src4->data +    i2 *(src4->nb[1])); // {d_state, n_tokens}
+        float * y  = (float *) ((char *)  dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+        float * s  = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + src1->nb[2]); // {d_state, d_inner, n_kv}
+        float * x  = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
+        float * A  = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+        float * B  = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
+        float * C  = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));
             float x_dt = x[i1] * dt_soft_plus;
+            float sumf = 0.0f;
             // d_state
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
-                // ssm_state * dA + dB * x
-                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                // state = prev_state * dA + dB * x
+                float state = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                // y = rowwise_dotprod(state, C)
+                sumf += state*C[i0];
+                // FIXME: handle simultaneous sequences
+                s[i] = state;
             }
+            y[i1] = sumf;
         }
     }
 }
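For reference, the recurrence the two loops above implement, written as a minimal standalone C function for a single token and a single sequence, detached from ggml's strides and threading. All names here are illustrative, not part of the commit:

#include <math.h>

// One scan step: s is the {d_state, d_inner} state, updated in place;
// x and dt are this token's {d_inner} inputs; B and C are this token's
// {d_state} columns; y receives the {d_inner} outputs.
static void ssm_scan_step(int d_inner, int d_state,
        float * s, const float * x, const float * dt,
        const float * A, const float * B, const float * C, float * y) {
    for (int i1 = 0; i1 < d_inner; ++i1) {
        const float dt_sp = log1pf(expf(dt[i1])); // softplus(dt)
        const float x_dt  = x[i1] * dt_sp;
        float sumf = 0.0f;
        for (int i0 = 0; i0 < d_state; ++i0) {
            const int i = i0 + i1*d_state;
            // state = prev_state * exp(dt*A) + (dt*B) * x
            const float state = s[i]*expf(dt_sp * A[i]) + B[i0]*x_dt;
            sumf += state * C[i0]; // y row = dot(state row, C)
            s[i] = state;
        }
        y[i1] = sumf;
    }
}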
@@ -14730,11 +14753,12 @@ static void ggml_compute_forward_ssm_scan(
         const struct ggml_tensor * src2,
         const struct ggml_tensor * src3,
         const struct ggml_tensor * src4,
+        const struct ggml_tensor * src5,
               struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, dst);
+                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, dst);
             } break;
         default:
             {
@@ -15796,7 +15820,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SSM_SCAN:
             {
-                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
+                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor);
             } break;
         case GGML_OP_WIN_PART:
             {

ggml.h (2 additions, 1 deletion)
@@ -1708,7 +1708,8 @@ extern "C" {
             struct ggml_tensor  * x,
             struct ggml_tensor  * dt,
             struct ggml_tensor  * A,
-            struct ggml_tensor  * B);
+            struct ggml_tensor  * B,
+            struct ggml_tensor  * C);

     // partition into non-overlapping windows with padding if needed
     // example:

llama.cpp (14 additions, 19 deletions)
@@ -8020,9 +8020,9 @@ struct llm_build_context {
             ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
             ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner),  kv_self.size);

+            // clear states of sequences which are starting at the beginning of this batch
             {
                 ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
-                // clear states of sequences which are starting at the beginning of this batch
                 conv_states = ggml_mul(ctx0,
                     ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
                     state_mask);
@@ -8031,11 +8031,8 @@ struct llm_build_context {
                     state_mask);
             }

-            // TODO: support more than one sequence per batch (these could then use ggml_reshape_3d)
-            ggml_tensor * conv_state = ggml_view_2d(ctx0, conv_states, d_conv - 1, d_inner,
-                    (d_conv - 1)*ggml_element_size(conv_states), 0);
-            ggml_tensor * ssm_state  = ggml_view_2d(ctx0, ssm_states, d_state, d_inner,
-                    (d_state)*ggml_element_size(ssm_states), 0);
+            struct ggml_tensor * conv_state = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
+            struct ggml_tensor * ssm_state  = ggml_reshape_3d(ctx0, ssm_states,  d_state,    d_inner, n_kv);

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -8090,7 +8087,7 @@ struct llm_build_context {

             // ssm
             {
-                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
+                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
                 struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
                 // split
                 struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
@@ -8101,22 +8098,20 @@ struct llm_build_context {
                 dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
                 dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);

-                // Custom operator to implement some of the optimizations
-                // described in the Annex D of the Mamba paper.
-                // TODO: maybe also optimize step 4 of the Speed section of Annex D (the mul_mat with C)
-                // => {d_state, d_inner, n_tokens}
-                ssm_state = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B);
+                // Custom operator to optimize the parallel associative scan
+                // as described in the Annex D of the Mamba paper.
+                // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
+                // because only a single tensor can be returned.
+                struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B, C);

-                // only store last state
+                // store last states (the second part of y_ssm_states)
                 ggml_build_forward_expand(gf,
                     ggml_cpy(ctx0,
-                        ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tokens-1)*ssm_state->nb[2]),
-                        ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state))));
+                        ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
+                        ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state))));
+
+                struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);

-                // {d_state, d_inner, n_tokens} * {d_state, n_tokens} => {d_inner, 1, n_tokens}
-                struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
-                // => {d_inner, n_tokens}
-                y = ggml_permute(ctx0, y, 0, 2, 1, 3);
                 // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
                 y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
                 y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
