Commit 5acf897

mamba : multiple sequences, but one at a time
This is a step towards making this Mamba implementation usable with the server example (the way the system prompt is kept when clearing the client slots will need to be changed before this can work, though).

The KV cache size for this kind of model is tied to the maximum number of sequences kept at any single time. For now, this number is obtained from n_parallel (plus one, to have an extra sequence to dedicate to the system prompt), but there might be a better way to do this which won't also make the main example use 2 cells even if only 1 is really used. (for this specific case, --parallel 0 helps)

Simultaneous sequence processing will probably require changes to ggml_ssm_scan, and possibly a new operator for the conv step.

* mamba : support llama_kv_cache_seq_cp

This (mis)uses the logic around K shifts, because tokens in a state can't be shifted anyway, and because inp_K_shift has the right shape and type. Using ggml_get_rows is a nice way to do copies, but copy chains can't work. Fortunately, copy chains don't really seem to be used in the examples. Each KV cell is dedicated to the sequence ID corresponding to its own index.

* mamba : use a state mask

It's cleaner than the previous heuristic of checking for the pos of the first token in the batch. inp_KQ_mask could not be re-used for this, because it has the wrong shape and because it seems more suited to the next step of simultaneous sequence processing (helping with the problem of remembering which token belongs to which sequence(s)/state(s)).

* llama : replace the usage of n_ctx with kv_self.size in many places

* mamba : use n_tokens directly instead of n_tok
1 parent 5db47b8 commit 5acf897
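
The sequence-copy and state-mask mechanisms described in the message boil down to two per-cell operations on the recurrent states: copying a cell is a gather with per-cell source indices (which is what ggml_get_rows expresses on the state tensor), and clearing a cell is a multiplication by a 0/1 mask. The standalone C sketch below is an illustration of that idea only; it is not the actual llama.cpp code, and the function and parameter names are made up.

    #include <stdlib.h>
    #include <string.h>

    // One recurrent state (row) per KV cell; cell i is dedicated to seq_id i.
    // src_index[i] == i means "keep this cell's own state";
    // src_index[dst] == src copies cell src into cell dst.
    // clear_mask[i] == 0.0f wipes a cell before it is used, 1.0f keeps it.
    static void apply_copies_and_mask(float * states, int n_cells, int n_state,
                                      const int * src_index, const float * clear_mask) {
        float * gathered = malloc(sizeof(float) * (size_t) n_cells * n_state);
        for (int i = 0; i < n_cells; ++i) {
            const float * src = states + (size_t) src_index[i] * n_state;
            for (int j = 0; j < n_state; ++j) {
                gathered[(size_t) i * n_state + j] = src[j] * clear_mask[i];
            }
        }
        memcpy(states, gathered, sizeof(float) * (size_t) n_cells * n_state);
        free(gathered);
    }

Because the gather reads only the pre-update states, a chain like seq_cp(a, b) followed by seq_cp(b, c) in the same update would give c the old contents of b rather than a, which is the copy-chain limitation the message mentions.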

File tree

4 files changed: +259 −91 lines changed


common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1283,6 +1283,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 
     cparams.n_ctx = params.n_ctx;
     cparams.n_batch = params.n_batch;
+    cparams.n_parallel = params.n_parallel;
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.seed = params.seed;

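This one-line change forwards --parallel to the context parameters; per the commit message, the number of recurrent-state cells is then derived from it, with one extra cell reserved for the system prompt. A hypothetical sketch of that sizing rule follows (the helper name is made up; the real logic lives in llama.cpp's context setup, which is not shown in this diff):

    #include <stdint.h>

    // One state cell per parallel sequence, plus one reserved for the system
    // prompt. With --parallel 0 this collapses to a single cell, which is why
    // that flag avoids a wasted cell in the main example.
    static uint32_t mamba_state_cells(uint32_t n_parallel) {
        return n_parallel + 1;
    }
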
ggml.c

Lines changed: 13 additions & 13 deletions
@@ -6097,15 +6097,15 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
 
     {
-        const int64_t d_state = s->ne[0];
-        const int64_t d_inner = s->ne[1];
-        const int64_t n_tok   = x->ne[1];
+        const int64_t d_state  = s->ne[0];
+        const int64_t d_inner  = s->ne[1];
+        const int64_t n_tokens = x->ne[1];
 
         GGML_ASSERT(x->ne[0] == d_inner);
         GGML_ASSERT(A->ne[0] == d_state);
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_tok);
+        GGML_ASSERT(B->ne[1] == n_tokens);
     }
 
     bool is_node = false;
@@ -14682,12 +14682,12 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // first batch
     {
-        float * pdst = (float *) ((char *) dst->data  + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data  + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
         float * s    = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
-        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
-        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
+        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
+        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
         float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B    = (float *) ((char *) src4->data); // {d_state, n_tok}
+        float * B    = (float *) ((char *) src4->data); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));
@@ -14703,12 +14703,12 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // compute state for rest of tokens, previous state comes from dest
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * pdst = (float *) ((char *) dst->data  + ir0*( dst->nb[1]) +    i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
-        float * s    = (float *) ((char *) dst->data  + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
-        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0]) +    i2 *(src1->nb[1])); // {d_inner, n_tok}
-        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0]) +    i2 *(src2->nb[1])); // {d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data  + ir0*( dst->nb[1]) +    i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens}
+        float * s    = (float *) ((char *) dst->data  + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
+        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0]) +    i2 *(src1->nb[1])); // {d_inner, n_tokens}
+        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0]) +    i2 *(src2->nb[1])); // {d_inner, n_tokens}
         float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B    = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tok}
+        float * B    = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));

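For reference, stripped of strides and threading, the update these hunks perform is the standard discretized state-space recurrence: softplus is applied to dt (the log1pf(expf(...)) above), the previous state decays by expf(dt * A), and dt * B * x is added; each token's state is written to dst so the next token reads its predecessor from there ("previous state comes from dest"). The C projection is not part of this operator (its inputs are only s, x, dt, A and B), so it is applied elsewhere in the graph. The sketch below is a simplified single-channel illustration under those assumptions, not the actual ggml kernel:

    #include <math.h>

    // One inner channel: h has d_state entries and is updated in place;
    // h_out receives one d_state-sized state per token, like dst above.
    static void ssm_scan_channel(float * h, float * h_out,
                                 const float * x,  // n_tokens values for this channel
                                 const float * dt, // n_tokens values for this channel
                                 const float * A,  // d_state values for this channel
                                 const float * B,  // n_tokens x d_state values
                                 int n_tokens, int d_state) {
        for (int t = 0; t < n_tokens; ++t) {
            const float dt_soft_plus = log1pf(expf(dt[t]));
            const float x_dt = x[t] * dt_soft_plus;
            for (int i = 0; i < d_state; ++i) {
                // h = h * exp(dt * A) + (dt * B) * x
                h[i] = h[i] * expf(dt_soft_plus * A[i]) + B[t*d_state + i] * x_dt;
                h_out[t*d_state + i] = h[i];
            }
        }
    }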