Commit 2fa5f2c

graph : fix recurrent state copies when avoiding copies

Works, but using lambda functions might not be that clean.

1 parent: 9864bfc

File tree

3 files changed: +38 -28 lines


src/llama-graph.cpp

Lines changed: 7 additions & 12 deletions

@@ -1429,7 +1429,8 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
         ggml_tensor * state_copy,
             int32_t   state_size,
             int32_t   n_seqs,
-               bool   avoid_copies) const {
+    const std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)> & get_state_rows) const {
+
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto n_kv = kv_state->get_n_kv();
@@ -1445,17 +1446,11 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
 
     ggml_tensor * output_states;
 
-    if (!avoid_copies) {
-        // copy states
-        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-        // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
-        ggml_build_forward_expand(gf, output_states);
-    } else {
-        // FIXME: make the gathering operation happen before the copy below
-        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-        output_states = states;
-    }
+    // copy states
+    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+    // {state_size, kv_size} -> {state_size, n_seqs}
+    output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    ggml_build_forward_expand(gf, output_states);
 
     // copy extra states which won't be changed further (between n_seqs and n_kv)
     ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
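
Note: the removed FIXME proposed exactly this change. The boolean avoid_copies is replaced by a get_state_rows callback, and the default keeps the old copying behavior because ggml_get_rows already matches the callback's signature. A minimal sketch of why the default argument binds without a wrapper (assuming only ggml's public API):

#include <functional>

// ggml_get_rows(ggml_context *, ggml_tensor *, ggml_tensor *) -> ggml_tensor *
// has exactly the shape the new parameter expects:
using get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor *, ggml_tensor *)>;
get_rows_fn get_state_rows = ggml_get_rows; // default path == old !avoid_copies path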

src/llama-graph.h

Lines changed: 2 additions & 1 deletion

@@ -599,7 +599,8 @@ struct llm_graph_context {
         ggml_tensor * state_copy,
             int32_t   state_size,
             int32_t   n_seqs,
-               bool   avoid_copies = false) const;
+    const std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>
+        & get_state_rows = ggml_get_rows) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
         ggml_cgraph * gf,
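
Since the new parameter defaults to ggml_get_rows, existing call sites compile unchanged; only callers that want to fuse work into the gather pass a lambda. A minimal sketch of the two call styles (names mirror the Mamba builder below; the lambda body is illustrative only):

// default: plain gather, identical to the previous behavior
ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy,
                                           hparams.n_embd_k_s(), n_seqs);

// opt-in: the caller decides how the selected state rows are consumed
auto get_rows_fused = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
    return ggml_get_rows(ctx, states, ids); // a real caller fuses its own op here
};
ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy,
                                          hparams.n_embd_v_s(), n_seqs, get_rows_fused);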

src/llama-model.cpp

Lines changed: 29 additions & 15 deletions

@@ -9024,11 +9024,8 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * conv_states_all = kv_state->get_k_l(il);
         ggml_tensor * ssm_states_all  = kv_state->get_v_l(il);
 
-        // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true);
-        ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
         cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
@@ -9094,11 +9091,21 @@ struct llm_build_mamba : public llm_graph_context {
         cur = x;
         x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
 
-        ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0);
-        // Custom operator to optimize the parallel associative scan
-        // as described in the Annex D of the Mamba paper.
-        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-        ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids);
+        ggml_tensor * A = model.layers[il].ssm_a;
+
+        // use the states and the indices provided by build_recurrent_state
+        // (this is necessary in order to properly use the states before they are overwritten,
+        //  while avoiding to make unnecessary copies of the states)
+        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
+
+            // Custom operator to optimize the parallel associative scan
+            // as described in the Annex D of the Mamba paper.
+            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+        };
+
+        ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows);
 
         // store last states
         ggml_build_forward_expand(gf,
@@ -9151,11 +9158,8 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * conv_states_all = kv_state->get_k_l(il);
         ggml_tensor * ssm_states_all  = kv_state->get_v_l(il);
 
-        // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-        ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true);
-        ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
         cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
@@ -9211,10 +9215,20 @@ struct llm_build_mamba : public llm_graph_context {
         // {n_head, n_seq_tokens, n_seqs}
         dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
 
-        ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0);
-        // TODO: use semistructured matrices to implement state-space duality
-        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-        ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids);
+        ggml_tensor * A = model.layers[il].ssm_a;
+
+        // use the states and the indices provided by build_recurrent_state
+        // (this is necessary in order to properly use the states before they are overwritten,
+        //  while avoiding to make unnecessary copies of the states)
+        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
+
+            // TODO: use semistructured matrices to implement state-space duality
+            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+        };
+
+        ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows);
 
         // store last states
         ggml_build_forward_expand(gf,
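
Both the Mamba and Mamba-2 builders now share the same pattern; a lambda is used (rather than a free function) because the scan needs x, dt, A, B and C captured from the enclosing scope. Condensed from the hunks above (a sketch, not verbatim):

// the lambda receives the *full* state buffer (all kv cells) and lets
// ggml_ssm_scan select the rows it needs via `ids`, so no intermediate
// {state_size, n_seqs} copy is materialized
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
    ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
    return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
};

// build_recurrent_state expands the lambda's result into the graph before
// scheduling the remaining state copies, so the previous states are read
// before anything overwrites them
ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy,
                                            hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows);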
