@@ -1782,6 +1782,7 @@ struct llama_layer {
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
+ int32_t src = 0; // used by recurrent state models to copy states
std::set<llama_seq_id> seq_id;
@@ -1802,6 +1803,7 @@ struct llama_kv_cell {
struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
+ bool do_copy = false;
// with Mamba, a cell can hold the state for more than one past token
bool unlimited = false;
@@ -2040,7 +2042,8 @@ struct llama_context {
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
- struct ggml_tensor * inp_s_mask; // F32 [kv_size] (only used by constant state models like Mamba)
+ struct ggml_tensor * inp_s_copy; // I32 [kv_size]
+ struct ggml_tensor * inp_s_mask; // F32 [kv_size]
struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
#ifdef GGML_USE_MPI
@@ -2082,9 +2085,9 @@ static bool llama_kv_cache_init(
if (cache.unlimited) {
for (uint32_t i = 0; i < cache.size; ++i) {
- cache.cells[i].delta = i;
+ cache.cells[i].src = i;
}
- } // else, delta is already initialized to zero
+ }
#ifdef GGML_USE_CLBLAST
offload = false;
@@ -2335,19 +2338,20 @@ static void llama_kv_cache_seq_cp(
if (cache.unlimited) {
if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
- seq_id_src = cache.cells[seq_id_src].delta;
+ seq_id_src = cache.cells[seq_id_src].src;
GGML_ASSERT((uint32_t) seq_id_src < cache.size);
// intent to "copy from"
// supports copy chains thanks to taking the source of the source
- cache.cells[seq_id_dst].delta = seq_id_src;
+ cache.cells[seq_id_dst].src = seq_id_src;
- // prevent the destination from getting cleared if the source is not empty
+ // preserve the "keep or clear" status of the copied sequence
if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+ } else {
+ cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
}
- // repurposed as a "need copy" flag
- // (shifting can't be done anyway for this kind of KV cache)
- cache.has_shift = true;
+
+ cache.do_copy = true;
cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
}
@@ -5436,21 +5440,7 @@ struct llm_build_context {
struct ggml_cgraph * build_k_shift() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
- // TODO: do this in a another graph with a dedicated input tensor
- if (kv_self.unlimited) {
- for (int il = 0; il < n_layer; ++il) {
- ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
- ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
-
- conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift);
- ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_K_shift);
-
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
- }
-
- return gf;
- }
+ GGML_ASSERT(kv_self.size == n_ctx);
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
@@ -5470,6 +5460,25 @@ struct llm_build_context {
return gf;
}
+ struct ggml_cgraph * build_s_copy() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
+ ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+
+ conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+ ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+
+ // TODO: name the intermediate tensors with cb()
+
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
+ }
+
+ return gf;
+ }
+
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@@ -8202,6 +8211,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
return result;
}
+ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+ llama_batch dummy;
+ dummy.n_tokens = 0;
+
+ llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+ struct llm_build_context llm(lctx, dummy, cb, false);
+
+ llm.init();
+
+ struct ggml_cgraph * result = llm.build_s_copy();
+
+ llm.free();
+
+ return result;
+ }
+
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_batch & batch,
@@ -8341,6 +8367,18 @@ static void llama_set_k_shift(llama_context & lctx) {
}
}
+ static void llama_set_s_copy(llama_context & lctx) {
+ const int64_t kv_size = lctx.kv_self.size;
+
+ assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+ int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+ for (int i = 0; i < kv_size; ++i) {
+ data[i] = lctx.kv_self.cells[i].src;
+ }
+ }
+
static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
//
// set input data
@@ -8455,17 +8493,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
if (kv_self.unlimited) {
- const int64_t n_kv = kv_self.n;
+ const int64_t n_kv = kv_self.n;
{
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
float * data = (float *) lctx.inp_s_mask->data;
// states which are not affected by the current batch are left untouched
for (int i = 0; i < n_kv; ++i) {
- llama_seq_id seq_id = i + lctx.kv_self.head;
- llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
- bool has_self_seq = kv_cell.has_seq_id(seq_id);
+ llama_seq_id seq_id = i + lctx.kv_self.head;
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
+ bool has_self_seq = kv_cell.has_seq_id(seq_id);
data[i] = (float) has_self_seq;
@@ -8988,7 +9026,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
static void llama_kv_cache_update_internal(struct llama_context & lctx) {
// apply K-shift if needed
- if ((lctx.kv_self.unlimited || lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) && lctx.kv_self.has_shift) {
+ if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
llama_set_k_shift(lctx);
{
@@ -9003,7 +9041,27 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
kv_self.has_shift = false;
for (uint32_t i = 0; i < kv_self.size; ++i) {
- kv_self.cells[i].delta = kv_self.unlimited ? i : 0;
+ kv_self.cells[i].delta = 0;
+ }
+ }
+ }
+
+ if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+ llama_set_s_copy(lctx);
+
+ {
+ ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+ llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+ }
+
+ {
+ auto & kv_self = lctx.kv_self;
+
+ kv_self.do_copy = false;
+
+ for (uint32_t i = 0; i < kv_self.size; ++i) {
+ kv_self.cells[i].src = i;
}
}
}
@@ -12644,7 +12702,7 @@ struct llama_context * llama_new_context_with_model(
// graph inputs
{
ggml_init_params init_params = {
- /* .mem_size */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)),
+ /* .mem_size */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
/* .mem_buffer */ nullptr,
/* .no_alloc */ true,
};
@@ -12659,6 +12717,7 @@ struct llama_context * llama_new_context_with_model(
ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
if (ctx->kv_self.unlimited) {
+ ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
}
@@ -12672,6 +12731,7 @@ struct llama_context * llama_new_context_with_model(
ggml_set_name(ctx->inp_mean, "inp_mean");
ggml_set_name(ctx->inp_cls, "inp_cls");
if (ctx->kv_self.unlimited) {
+ ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
}