@@ -9024,11 +9024,8 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * conv_states_all = kv_state->get_k_l(il);
         ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
 
-        // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true);
-        ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
         cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
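For context, the `d_conv - 1` columns kept per channel here are the tail of the input window that the causal conv1d needs across ubatch boundaries. A minimal scalar sketch of that rolling cache in plain C++, with illustrative names only (not the ggml implementation):

```cpp
// Sketch: per-channel causal conv1d cache of d_conv - 1 past inputs,
// rolled forward one token at a time. Illustrative, not llama.cpp code.
#include <cstdio>
#include <vector>

int main() {
    const int d_conv = 4;                        // kernel width
    std::vector<float> state(d_conv - 1, 0.0f);  // last d_conv - 1 inputs of one channel
    const float kernel[d_conv] = {0.1f, 0.2f, 0.3f, 0.4f};

    for (int t = 0; t < 6; ++t) {
        const float x = (float) t; // current input for this channel
        // full window = cached tail followed by the current input
        float y = kernel[d_conv - 1]*x;
        for (int i = 0; i < d_conv - 1; ++i) {
            y += kernel[i]*state[i];
        }
        // roll the cache: drop the oldest sample, append the newest
        for (int i = 0; i < d_conv - 2; ++i) {
            state[i] = state[i + 1];
        }
        state[d_conv - 2] = x;
        printf("t=%d y=%f\n", t, y);
    }
}
```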
@@ -9094,11 +9091,21 @@ struct llm_build_mamba : public llm_graph_context {
         cur = x;
         x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
 
-        ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0);
-        // Custom operator to optimize the parallel associative scan
-        // as described in the Annex D of the Mamba paper.
-        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-        ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids);
+        ggml_tensor * A = model.layers[il].ssm_a;
+
+        // use the states and the indices provided by build_recurrent_state
+        // (this is necessary in order to properly use the states before they are overwritten,
+        // while avoiding making unnecessary copies of the states)
+        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
+
+            // Custom operator to optimize the parallel associative scan
+            // as described in Annex D of the Mamba paper.
+            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+        };
+
+        ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows);
 
         // store last states
         ggml_build_forward_expand(gf,
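The substantive change in this hunk: instead of returning a copied state tensor, `build_recurrent_state` now takes a callback and applies it to the gathered state rows, so the scan consumes the old states before the graph overwrites them in place. A schematic sketch of that contract, using hypothetical stub types and a simplified signature (the real declaration is not part of this diff):

```cpp
// Sketch of the callback contract. The types, names, and signature here
// are illustrative stand-ins, not the actual llama.cpp declarations.
#include <cstdio>
#include <functional>

struct tensor { const char * name; };

// The builder gathers the state rows for the sequences in the ubatch and
// hands them, together with the row indices, to the caller's callback.
// Whatever the callback returns becomes the result of the whole call, so
// the old states are read *before* the graph overwrites them in place.
tensor * build_recurrent_state_like(
        tensor * states_all,
        tensor * ids,
        const std::function<tensor * (tensor * states, tensor * ids)> & get_rows) {
    // ... view/gather the relevant rows of states_all here ...
    return get_rows(states_all, ids);
}

int main() {
    tensor states  = {"ssm_states_all"};
    tensor ids     = {"state_copy"};
    tensor scanned = {"y_ssm"};

    // mirrors the get_ssm_rows lambda: reshape the rows, then scan them
    auto get_ssm_rows = [&](tensor * s, tensor * i) -> tensor * {
        printf("scanning %s with indices %s\n", s->name, i->name);
        return &scanned;
    };

    tensor * y = build_recurrent_state_like(&states, &ids, get_ssm_rows);
    printf("result: %s\n", y->name);
}
```

This is also why the `ggml_reshape_4d` and `ggml_ssm_scan` calls moved inside the lambda: they must run on the rows the builder provides, at the point the builder chooses.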
@@ -9151,11 +9158,8 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * conv_states_all = kv_state->get_k_l(il);
         ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
 
-        // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-        ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true);
-        ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
         cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
@@ -9211,10 +9215,20 @@ struct llm_build_mamba : public llm_graph_context {
         // {n_head, n_seq_tokens, n_seqs}
         dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
 
-        ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0);
-        // TODO: use semistructured matrices to implement state-space duality
-        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-        ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids);
+        ggml_tensor * A = model.layers[il].ssm_a;
+
+        // use the states and the indices provided by build_recurrent_state
+        // (this is necessary in order to properly use the states before they are overwritten,
+        // while avoiding making unnecessary copies of the states)
+        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
+
+            // TODO: use semistructured matrices to implement state-space duality
+            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+        };
+
+        ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows);
 
         // store last states
         ggml_build_forward_expand(gf,
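For reference, `ggml_ssm_scan` fuses the selective-scan recurrence h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t, y_t = C_t * h_t over the sequence. A scalar sketch for one channel and one state dimension, assuming the standard Mamba discretization (illustrative only, not the fused kernel):

```cpp
// Scalar reference for the recurrence that the fused scan operator computes,
// for a single channel and a single state dimension. Illustrative only.
#include <cmath>
#include <cstdio>

int main() {
    const int n_tokens = 4;
    const float A  = -1.0f; // state decay (kept in log space in the real model)
    const float x[n_tokens]  = {1.0f, 0.5f, -0.25f, 2.0f};
    const float dt[n_tokens] = {0.1f, 0.2f,  0.1f,  0.3f};
    const float B[n_tokens]  = {1.0f, 1.0f,  1.0f,  1.0f};
    const float C[n_tokens]  = {0.5f, 0.5f,  0.5f,  0.5f};

    float h = 0.0f; // recurrent state, carried across ubatches via ssm_states_all
    for (int t = 0; t < n_tokens; ++t) {
        // discretize and step the state: h' = exp(dt*A)*h + dt*B*x
        h = std::exp(dt[t]*A)*h + dt[t]*B[t]*x[t];
        // project the state to the output: y = C*h
        const float y = C[t]*h;
        printf("t=%d h=%f y=%f\n", t, h, y);
    }
}
```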