Commit 6663128

kv-cache : rework kv_idxs, support seq_cp
ggml-ci
1 parent 0bb1da5 commit 6663128

5 files changed: +229 -98 lines changed

examples/parallel/parallel.cpp

Lines changed: 4 additions & 9 deletions
@@ -290,10 +290,8 @@ int main(int argc, char ** argv) {
     for (int i = 1; i <= n_clients; ++i) {
         llama_memory_seq_rm(mem, i, -1, -1);
 
-        if (is_sp_shared) {
-            // but keep the system prompt
-            llama_memory_seq_cp(mem, 0, i, -1, -1);
-        }
+        // but keep the system prompt
+        llama_memory_seq_cp(mem, 0, i, -1, -1);
     }
 
     LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -452,11 +450,8 @@ int main(int argc, char ** argv) {
             }
 
             // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-            llama_memory_seq_rm(mem, client.id + 1, -1, -1);
-
-            if (is_sp_shared) {
-                llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
-            }
+            llama_memory_seq_rm(mem, client.id + 1, -1, -1);
+            llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
 
             const auto t_main_end = ggml_time_us();

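With seq_cp now supported by the reworked unified KV cache, the example no longer needs the is_sp_shared guard: a client slot is cleared and then unconditionally re-linked to the system prompt held in sequence 0. A minimal sketch of that reset pattern against the public llama_memory API (the helper name and the surrounding client bookkeeping are illustrative, not part of this commit):

#include "llama.h"

// Illustrative helper (not from this commit): reset one client slot while
// keeping the shared system prompt that is cached under sequence 0.
static void reset_client_seq(llama_memory_t mem, llama_seq_id client_seq) {
    // drop everything that belongs to this client ...
    llama_memory_seq_rm(mem, client_seq, -1, -1);

    // ... then re-share the system prompt from sequence 0
    llama_memory_seq_cp(mem, 0, client_seq, -1, -1);
}
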
src/llama-graph.cpp

Lines changed: 32 additions & 18 deletions
@@ -281,8 +281,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->set_input_v_idxs(self_v_idxs, ubatch);
     }
 
     if (self_kq_mask) {
@@ -291,12 +295,20 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
+    }
+
+    if (self_k_idxs_swa) {
+        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
     }
 
-    if (self_kv_idxs_swa) {
-        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    if (self_v_idxs_swa) {
+        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
     }
 
     if (self_kq_mask) {
@@ -1209,8 +1221,8 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     const auto n_kv = mctx_cur->get_n_kv();
     const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
-    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-    ggml_set_input(inp->self_kv_idxs);
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
     inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
     ggml_set_input(inp->self_kq_mask);
@@ -1243,10 +1255,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = inp->get_kv_idxs();
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1299,10 +1312,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1444,8 +1458,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
-        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs);
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
@@ -1458,8 +1472,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
-        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs_swa);
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
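
The kv-cache side of this interface is not part of the hunks shown here. Judging only from the call sites above, build_input_k_idxs/build_input_v_idxs presumably create and register the index tensors so that the cache, rather than the graph code, decides their type and shape, while set_input_k_idxs/set_input_v_idxs fill them per ubatch. A rough sketch under those assumptions (signatures and bodies are guesses, not the actual implementation):

// Sketch only: inferred from the call sites above, not from the real
// llama-kv-cache-unified implementation; would live next to the kv-cache code.
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) {
    // one destination cell index per token in the ubatch
    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, ubatch.n_tokens);
    ggml_set_input(k_idxs);
    return k_idxs;
}

ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) {
    // per the comments added in llama-graph.h this is either [n_batch] or
    // [n_batch*n_embd_v_gqa], depending on the V-cache layout; the per-token
    // variant is shown here
    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, ubatch.n_tokens);
    ggml_set_input(v_idxs);
    return v_idxs;
}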

src/llama-graph.h

Lines changed: 14 additions & 7 deletions
@@ -248,11 +248,13 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
-    // TODO: should this be I64?
-    ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
@@ -277,13 +279,18 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
-    ggml_tensor * get_kv_idxs_swa() const { return self_kv_idxs_swa; }
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
-    ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
-    ggml_tensor * self_kv_idxs_swa = nullptr; // I32 [n_batch]
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
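
The widened self_v_idxs comment ([n_batch] or [n_batch*n_embd_v_gqa]) indicates that, depending on the V-cache layout, either one index per token or one index per value element is uploaded. Whatever the shape, set_input() ultimately copies precomputed destination indices into these I64 tensors; a generic sketch of that fill step (the helper name and the cells vector are illustrative, not this commit's code):

#include <cstring>
#include <vector>

// Illustrative sketch: copy the destination cell index chosen for each
// ubatch token (or each value element) into the I64 input tensor.
static void fill_idxs(ggml_tensor * dst, const std::vector<int64_t> & cells) {
    GGML_ASSERT(dst->type == GGML_TYPE_I64);
    GGML_ASSERT((size_t) ggml_nelements(dst) == cells.size());

    std::memcpy(dst->data, cells.data(), cells.size()*sizeof(int64_t));
}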
