
Commit 7fa8d6b

mamba : dedicate an input tensor for state copy indices
This is cleaner and makes it easier to adapt when/if token positions (and by extension, inp_K_shift) are no longer integers.
1 parent 3d48f97 commit 7fa8d6b
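
In short, each cell of the recurrent-state cache records in `src` which cell it should take its state from, and a dedicated I32 input tensor (inp_s_copy) carries those indices into a small graph that gathers the per-layer states. A minimal sketch of that gather on plain arrays, with one scalar standing in for an entire per-layer state (names here are illustrative, not part of the patch):

    #include <cstdint>
    #include <vector>

    int main() {
        // one scalar stands in for an entire per-layer recurrent state
        std::vector<float>   state  = {10.f, 11.f, 12.f, 13.f};
        std::vector<int32_t> s_copy = {0, 0, 2, 2}; // cells 1 and 3 copy from cells 0 and 2

        std::vector<float> next(state.size());
        for (size_t i = 0; i < state.size(); ++i) {
            next[i] = state[s_copy[i]]; // the gather that the new input tensor drives
        }
        // next == {10, 10, 12, 12}
        return 0;
    }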

llama.cpp

Lines changed: 91 additions & 31 deletions
@@ -1782,6 +1782,7 @@ struct llama_layer {
 struct llama_kv_cell {
     llama_pos pos = -1;
     llama_pos delta = 0;
+    int32_t src = 0; // used by recurrent state models to copy states
 
     std::set<llama_seq_id> seq_id;
 
@@ -1802,6 +1803,7 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
+    bool do_copy = false;
     // with Mamba, a cell can hold the state for more than one past token
     bool unlimited = false;
 
@@ -2040,7 +2042,8 @@ struct llama_context {
     struct ggml_tensor * inp_K_shift; // I32 [kv_size]
     struct ggml_tensor * inp_mean;    // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;     // I32 [n_batch]
-    struct ggml_tensor * inp_s_mask;  // F32 [kv_size] (only used by constant state models like Mamba)
+    struct ggml_tensor * inp_s_copy;  // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask;  // F32 [kv_size]
     struct ggml_tensor * inp_s_seq;   // I32 [kv_size, n_batch]
 
 #ifdef GGML_USE_MPI
@@ -2082,9 +2085,9 @@ static bool llama_kv_cache_init(
 
     if (cache.unlimited) {
         for (uint32_t i = 0; i < cache.size; ++i) {
-            cache.cells[i].delta = i;
+            cache.cells[i].src = i;
         }
-    } // else, delta is already initialized to zero
+    }
 
 #ifdef GGML_USE_CLBLAST
     offload = false;
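
Initializing `src` to the cell's own index makes the copy a no-op by default: a cell whose state was not redirected simply gathers from itself. A hedged illustration of that identity invariant with stand-in types (not the actual llama.cpp structs):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct cell { int32_t src = 0; };

    int main() {
        std::vector<cell> cells(4);
        // identity initialization, as in llama_kv_cache_init for "unlimited" caches
        for (int32_t i = 0; i < (int32_t) cells.size(); ++i) {
            cells[i].src = i;
        }

        // redirect cell 2 to read its state from cell 0; the others keep their own
        cells[2].src = cells[0].src;
        assert(cells[2].src == 0 && cells[1].src == 1 && cells[3].src == 3);
        return 0;
    }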
@@ -2335,19 +2338,20 @@ static void llama_kv_cache_seq_cp(
 
     if (cache.unlimited) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            seq_id_src = cache.cells[seq_id_src].delta;
+            seq_id_src = cache.cells[seq_id_src].src;
             GGML_ASSERT((uint32_t) seq_id_src < cache.size);
             // intent to "copy from"
             // supports copy chains thanks to taking the source of the source
-            cache.cells[seq_id_dst].delta = seq_id_src;
+            cache.cells[seq_id_dst].src = seq_id_src;
 
-            // prevent the destination from getting cleared if the source is not empty
+            // preserve the "keep or clear" status of the copied sequence
             if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
                 cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            } else {
+                cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
             }
-            // repurposed as a "need copy" flag
-            // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = true;
+
+            cache.do_copy = true;
 
             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
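
The "copy chains" comment is the subtle part: because the destination takes the source of the source, several llama_kv_cache_seq_cp calls issued before the copy graph runs still resolve to the original cell. A rough walk-through with bare indices (hypothetical values, not code from the patch):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        std::vector<int32_t> src = {0, 1, 2, 3}; // identity: every cell owns its state

        // seq_cp(0 -> 1): cell 1 should copy from cell 0
        src[1] = src[0];                         // src = {0, 0, 2, 3}

        // seq_cp(1 -> 2) before any compute: taking the source of the source
        // points cell 2 straight at cell 0, not at the intermediate cell 1
        src[2] = src[1];                         // src = {0, 0, 0, 3}

        assert(src[2] == 0);                     // the chain collapses to the ultimate source
        return 0;
    }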
@@ -5436,21 +5440,7 @@ struct llm_build_context {
     struct ggml_cgraph * build_k_shift() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        // TODO: do this in a another graph with a dedicated input tensor
-        if (kv_self.unlimited) {
-            for (int il = 0; il < n_layer; ++il) {
-                ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
-                ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
-
-                conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift);
-                ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_K_shift);
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
-            }
-
-            return gf;
-        }
+        GGML_ASSERT(kv_self.size == n_ctx);
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * tmp =
@@ -5470,6 +5460,25 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_s_copy() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
+            ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+
+            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+            ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+
+            // TODO: name the intermediate tensors with cb()
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
+        }
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -8202,6 +8211,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
     return result;
 }
 
+static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+    llama_batch dummy;
+    dummy.n_tokens = 0;
+
+    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+    struct llm_build_context llm(lctx, dummy, cb, false);
+
+    llm.init();
+
+    struct ggml_cgraph * result = llm.build_s_copy();
+
+    llm.free();
+
+    return result;
+}
+
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_batch & batch,
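
The build_s_copy() graph above gathers rows with ggml_get_rows into new tensors and only then copies them back over kv_self.k_l / v_l with ggml_cpy; going through an intermediate means a cell that is both a source and a destination is read before it is overwritten. A host-side sketch of that gather-then-write-back pattern, assuming a flat [kv_size, row_size] layout (illustrative helper, not llama.cpp API):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Gather rows of a flattened [kv_size, row_size] state buffer according to
    // per-cell source indices, then write the result back, mirroring the
    // ggml_get_rows + ggml_cpy pair used per layer in build_s_copy().
    static void shuffle_states(std::vector<float> & states,
                               const std::vector<int32_t> & s_copy,
                               size_t row_size) {
        std::vector<float> tmp(states.size());
        for (size_t i = 0; i < s_copy.size(); ++i) {
            std::copy_n(states.begin() + (size_t) s_copy[i] * row_size, row_size,
                        tmp.begin() + i * row_size);
        }
        states.swap(tmp); // write back, as ggml_cpy does into the cache tensors
    }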
@@ -8341,6 +8367,18 @@ static void llama_set_k_shift(llama_context & lctx) {
     }
 }
 
+static void llama_set_s_copy(llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].src;
+    }
+}
+
 static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     //
     // set input data
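
llama_set_s_copy() above writes the indices straight into the tensor's host buffer, which is why it asserts ggml_backend_buffer_is_host. Roughly how such an I32 input can be created and filled with ggml when its data lives in host memory (sizes here are arbitrary; this is a sketch, not the context setup llama.cpp actually uses):

    #include <cstdint>
    #include "ggml.h"

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,   // allocate tensor data inside this context
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t kv_size = 8;
        struct ggml_tensor * s_copy = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, kv_size);

        int32_t * data = (int32_t *) s_copy->data;
        for (int64_t i = 0; i < kv_size; ++i) {
            data[i] = (int32_t) i; // identity mapping: every cell keeps its own state
        }

        ggml_free(ctx);
        return 0;
    }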
@@ -8455,17 +8493,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     if (kv_self.unlimited) {
-        const int64_t n_kv = kv_self.n;
+        const int64_t n_kv = kv_self.n;
 
         {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
             float * data = (float *) lctx.inp_s_mask->data;
 
             // states which are not affected by the current batch are left untouched
             for (int i = 0; i < n_kv; ++i) {
-                llama_seq_id seq_id = i + lctx.kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
-                bool has_self_seq = kv_cell.has_seq_id(seq_id);
+                llama_seq_id seq_id = i + lctx.kv_self.head;
+                llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
+                bool has_self_seq = kv_cell.has_seq_id(seq_id);
 
                 data[i] = (float) has_self_seq;
 
@@ -8988,7 +9026,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 
 static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     // apply K-shift if needed
-    if ((lctx.kv_self.unlimited || lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) && lctx.kv_self.has_shift) {
+    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
         llama_set_k_shift(lctx);
 
         {
@@ -9003,7 +9041,27 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             kv_self.has_shift = false;
 
             for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = kv_self.unlimited ? i : 0;
+                kv_self.cells[i].delta = 0;
+            }
+        }
+    }
+
+    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+        llama_set_s_copy(lctx);
+
+        {
+            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+        }
+
+        {
+            auto & kv_self = lctx.kv_self;
+
+            kv_self.do_copy = false;
+
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].src = i;
             }
         }
     }
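
Putting the pieces together, the new flag drives a three-step lifecycle: llama_kv_cache_seq_cp records the intent (src + do_copy), llama_kv_cache_update_internal runs the dedicated copy graph once, and the indices are then reset to the identity so the next update is a no-op until another copy is requested. A condensed sketch of that lifecycle with stand-in types (not the real llama.cpp structures):

    #include <cstdint>
    #include <vector>

    struct toy_cache {
        bool do_copy = false;
        std::vector<int32_t> src;

        explicit toy_cache(size_t n) : src(n) {
            for (size_t i = 0; i < n; ++i) src[i] = (int32_t) i;          // identity at init
        }

        void seq_cp(int32_t dst, int32_t s) {
            src[dst] = src[s];  // record the intent; chains collapse to the real source
            do_copy  = true;
        }

        template <typename ApplyFn>
        void update(ApplyFn && apply_copy_graph) {
            if (!do_copy) return;
            apply_copy_graph(src);                                        // run the gather once
            do_copy = false;
            for (size_t i = 0; i < src.size(); ++i) src[i] = (int32_t) i; // back to identity
        }
    };

    int main() {
        toy_cache cache(4);
        cache.seq_cp(/*dst=*/1, /*s=*/0);
        cache.update([](std::vector<int32_t> & s) { /* the s_copy graph would gather here */ (void) s; });
        return 0;
    }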
@@ -12644,7 +12702,7 @@ struct llama_context * llama_new_context_with_model(
         // graph inputs
         {
             ggml_init_params init_params = {
-                /* .mem_size */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)),
+                /* .mem_size */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
                 /* .mem_buffer */ nullptr,
                 /* .no_alloc */ true,
             };
@@ -12659,6 +12717,7 @@ struct llama_context * llama_new_context_with_model(
         ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
         ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
         if (ctx->kv_self.unlimited) {
+            ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
             ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
             ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
         }
@@ -12672,6 +12731,7 @@ struct llama_context * llama_new_context_with_model(
         ggml_set_name(ctx->inp_mean, "inp_mean");
         ggml_set_name(ctx->inp_cls, "inp_cls");
         if (ctx->kv_self.unlimited) {
+            ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
             ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
             ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
         }
