Skip to content

Commit 0bb1da5

Browse files
committed
kv-cache : simplify set_rows logic
ggml-ci
1 parent 165d822 commit 0bb1da5

File tree

1 file changed

+18
-45
lines changed

1 file changed

+18
-45
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 18 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -937,17 +937,17 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
937937
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
938938
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
939939
ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
940-
size_virt,
940+
size_virt, // v->nb[3]
941941
size_virt*sinfo.s0);
942942
}
943943

944944
// note: v->nb[1] > v->nb[2]
945945
return ggml_view_4d(ctx, v,
946946
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
947-
ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
948-
ggml_row_size(v->type, v->ne[1]), // v->nb[2]
949-
size_virt,
950-
size_virt*sinfo.s0);
947+
ggml_row_size(v->type, v->ne[1]*n_seq_virt*hparams.n_embd_head_v), // v->nb[1]
948+
ggml_row_size(v->type, v->ne[1]*n_seq_virt), // v->nb[2]
949+
ggml_row_size(v->type, v->ne[1]), // v->nb[3]
950+
ggml_row_size(v->type, v->ne[1]*sinfo.s0));
951951
}
952952

953953
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
@@ -961,20 +961,9 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
961961
k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
962962

963963
if (kv_idxs && supports_set_rows) {
964-
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
965-
966-
const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
967-
968-
ggml_tensor * k_view = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], ns,
969-
ggml_row_size(k->type, k->ne[0]),
970-
size_virt,
971-
size_virt*sinfo.s0);
972-
973-
k_cur = ggml_reshape_3d(ctx, k_cur, k_cur->ne[0], k_cur->ne[1]/ns, ns);
974-
975-
kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
964+
k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
976965

977-
return ggml_set_rows(ctx, k_view, k_cur, kv_idxs);
966+
return ggml_set_rows(ctx, k, k_cur, kv_idxs);
978967
}
979968

980969
// TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -1000,45 +989,27 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
1000989
v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
1001990

1002991
if (kv_idxs && supports_set_rows) {
1003-
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
1004-
1005-
const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
1006-
1007992
if (!v_trans) {
1008-
ggml_tensor * v_view = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], ns,
1009-
ggml_row_size(v->type, v->ne[0]),
1010-
size_virt,
1011-
size_virt*sinfo.s0);
1012-
1013-
v_cur = ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], v_cur->ne[1]/ns, ns);
1014-
1015-
kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
993+
v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
1016994

1017-
return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
995+
return ggml_set_rows(ctx, v, v_cur, kv_idxs);
1018996
}
1019997

1020998
// the row becomes a single element
1021-
ggml_tensor * v_view = ggml_view_4d(ctx, v, 1, v->ne[1], v->ne[0], ns,
1022-
ggml_row_size(v->type, 1),
1023-
ggml_row_size(v->type, v->ne[1]),
1024-
size_virt,
1025-
size_virt*sinfo.s0);
999+
ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);
10261000

10271001
// note: the V cache is transposed when not using flash attention
1028-
v_cur = ggml_permute(ctx, ggml_reshape_4d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]/ns, ns), 2, 0, 1, 3);
1002+
v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
10291003

10301004
// note: we can be more explicit here at the cost of extra cont
10311005
// however, above we take advantage that a row of single element is always contiguous regardless of the row stride
1032-
//v_cur = ggml_reshape_3d(ctx, v_cur, n_embd_v_gqa, v_cur->ne[1]/ns, ns);
10331006
//v_cur = ggml_transpose(ctx, v_cur);
1034-
//v_cur = ggml_cont_4d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1], v_cur->ne[2]);
1007+
//v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
10351008

10361009
// we broadcast the KV indices n_embd_v_gqa times
1037-
// v [1, n_kv, n_embd_v_gqa, ns]
1038-
// v_cur [1, n_tokens/ns, n_embd_v_gqa, ns]
1039-
// kv_idxs [n_tokens/ns, 1, ns]
1040-
1041-
kv_idxs = ggml_reshape_3d(ctx, kv_idxs, n_tokens/ns, 1, ns);
1010+
// v [1, n_kv*n_seq_virt, n_embd_v_gqa]
1011+
// v_cur [1, n_tokens, n_embd_v_gqa]
1012+
// kv_idxs [n_tokens, 1, 1]
10421013

10431014
return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
10441015
}
@@ -1077,8 +1048,10 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
10771048
int64_t * data = (int64_t *) dst->data;
10781049

10791050
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
1051+
const int64_t offs = sinfo.seq_id_virt[s]*get_size();
1052+
10801053
for (uint32_t i = 0; i < sinfo.size(); ++i) {
1081-
data[s*sinfo.size() + i] = sinfo.idxs[s][i];
1054+
data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
10821055
}
10831056
}
10841057
}

0 commit comments

Comments (0)