Commit 75e7e8c

kv-cache : prepare K/V buffer for separation

ggml-ci

1 parent e9b6a01

1 file changed: +51 −31
src/llama-kv-cache-unified.cpp

@@ -98,8 +98,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, 1);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, 1);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
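
The trailing dimension of size 1 is a placeholder for splitting the unified K/V buffer later; with a single buffer it changes nothing about the memory layout. A minimal standalone sketch of that equivalence (the sizes and names below are illustrative, not taken from the code):

#include <cassert>
#include <cstddef>

// Illustrative only (n_embd/kv_size are stand-ins, not the llama.cpp symbols):
// element offset in a row-major, dim-0-fastest layout, as ggml uses.
static size_t offset_2d(size_t i0, size_t i1, size_t n_embd) {
    return i0 + i1*n_embd;
}

static size_t offset_3d(size_t i0, size_t i1, size_t i2, size_t n_embd, size_t kv_size) {
    return i0 + i1*n_embd + i2*n_embd*kv_size;
}

int main() {
    const size_t n_embd  = 128;  // hypothetical n_embd_k_gqa
    const size_t kv_size = 4096;

    // With the new trailing dimension fixed at 1 (a single K/V buffer for now),
    // every element keeps the offset it had in the old 2D layout.
    for (size_t i1 = 0; i1 < kv_size; i1 += 512) {
        for (size_t i0 = 0; i0 < n_embd; i0 += 32) {
            assert(offset_2d(i0, i1, n_embd) == offset_3d(i0, i1, 0, n_embd, kv_size));
        }
    }
    return 0;
}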
@@ -785,33 +785,40 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
 
     auto * k = layers[ikv].k;
 
-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const uint64_t kv_size = get_size();
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, 1,
             ggml_row_size(k->type, hparams.n_embd_head_k),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size),
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const uint64_t kv_size = get_size();
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
+                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size), // v->nb[3]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            0);
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, kv_size),                       // v->nb[2]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il))*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
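
The views gain a fourth dimension and an explicit nb[3] stride of one full buffer; the offset is written as nb[3]*0, i.e. buffer 0, so for now the result is the same view as before. A rough, self-contained sketch of the stride arithmetic for the K view, assuming a non-quantized cache type so that ggml_row_size reduces to n*type_size (row_size below is a stand-in, not the ggml function):

#include <cstdint>
#include <cstdio>

// Stand-in for ggml_row_size() with a non-quantized type: bytes per row of n elements.
static uint64_t row_size(uint64_t n, uint64_t type_size) {
    return n*type_size;
}

int main() {
    // Hypothetical shapes; in the diff these come from hparams and get_size().
    const uint64_t n_embd_head_k = 128;
    const uint64_t n_head_kv     = 8;
    const uint64_t n_embd_k_gqa  = n_embd_head_k*n_head_kv;
    const uint64_t kv_size       = 4096;
    const uint64_t type_size     = 2; // e.g. F16

    // Strides of the 4D K view [n_embd_head_k, n_head_kv, n_kv, 1]:
    const uint64_t nb1 = row_size(n_embd_head_k, type_size);        // step between heads
    const uint64_t nb2 = row_size(n_embd_k_gqa,  type_size);        // step between cache cells
    const uint64_t nb3 = row_size(n_embd_k_gqa*kv_size, type_size); // step between whole K buffers

    // The view offset is nb3*ibuf; the diff hardcodes ibuf = 0, so the view
    // still starts at the beginning of the tensor, exactly as before the change.
    const uint64_t ibuf   = 0;
    const uint64_t offset = nb3*ibuf;

    printf("nb1=%llu nb2=%llu nb3=%llu offset=%llu\n",
           (unsigned long long) nb1, (unsigned long long) nb2,
           (unsigned long long) nb3, (unsigned long long) offset);
    return 0;
}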
@@ -850,24 +857,16 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
-
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: we can be more explicit here at the cost of extra cont
-        // however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        //v_cur = ggml_transpose(ctx, v_cur);
-        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1, n_kv, n_embd_v_gqa]
-        // v_cur  [1, n_tokens, n_embd_v_gqa]
-        // v_idxs [n_tokens, 1, 1]
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
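With the transposed V cache the destination rows are now single elements of a fully flattened cache, so ggml_set_rows can scatter each value independently and the previous permute of v_cur is no longer needed. A CPU-side sketch of what that scatter amounts to (set_rows_1elem and all sizes below are made up for illustration, not the ggml implementation):

#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of a set-rows scatter where each "row" is one element:
// destination row idx[r] receives source element r. A one-element row is
// contiguous regardless of the stride of the original transposed view.
static void set_rows_1elem(std::vector<float> & dst,
                           const std::vector<float>   & src,
                           const std::vector<int64_t> & idx) {
    for (size_t r = 0; r < idx.size(); ++r) {
        dst[(size_t) idx[r]] = src[r];
    }
}

int main() {
    // The whole transposed V cache viewed as a flat list of 1-element rows.
    std::vector<float> v_cache(32, 0.0f);

    // Flattened current values and their (hypothetical) destination rows.
    std::vector<float>   v_cur = {0.1f, 0.2f, 0.3f, 0.4f};
    std::vector<int64_t> v_idx = {5, 13, 21, 29};

    set_rows_1elem(v_cache, v_cur, v_idx);

    for (size_t r = 0; r < v_idx.size(); ++r) {
        printf("v_cache[%lld] = %.1f\n", (long long) v_idx[r], v_cache[(size_t) v_idx[r]]);
    }
    return 0;
}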

@@ -904,7 +903,14 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        // TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+    }
 
     ggml_set_input(v_idxs);
 
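In the transposed case the index tensor therefore holds one entry per (token, value component) pair instead of one per token. A tiny sketch of the resulting sizes, with hypothetical numbers:

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical batch and model sizes.
    const uint32_t n_tokens     = 32;
    const uint32_t n_embd_v_gqa = 1024;

    // !v_trans (flash attention): one row index per token.
    const uint64_t n_idx_rows = n_tokens;

    // v_trans: one index per (token, value component), since each destination
    // "row" of the transposed cache holds a single element.
    const uint64_t n_idx_elems = (uint64_t) n_tokens*n_embd_v_gqa;

    printf("!v_trans: %llu indices, v_trans: %llu indices\n",
           (unsigned long long) n_idx_rows, (unsigned long long) n_idx_elems);
    return 0;
}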

@@ -921,7 +927,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         data[i] = sinfo.idxs.at(i);
     }
 }
@@ -936,8 +942,22 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    if (!v_trans) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            data[i] = sinfo.idxs.at(i);
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        // TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                data[i*n_embd_v_gqa + j] = j*kv_size + sinfo.idxs.at(i);
+            }
+        }
     }
 }
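The nested loop maps component j of token i to flat index j*kv_size + sinfo.idxs.at(i), i.e. row j, column idxs[i] of the transposed [kv_size, n_embd_v_gqa] cache. A small self-contained sketch that applies the same formula to hypothetical sizes and checks that each index decodes back to that (component, cell) pair:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // Hypothetical sizes; in the diff these come from get_size(), hparams and sinfo.
    const int64_t kv_size      = 16;
    const int64_t n_embd_v_gqa = 4;
    const std::vector<int64_t> idxs = {2, 5, 11}; // cells assigned to the ubatch tokens
    const uint32_t n_tokens = (uint32_t) idxs.size();

    std::vector<int64_t> data(n_tokens*n_embd_v_gqa);

    // Same formula as the v_trans branch of set_input_v_idxs().
    for (uint32_t i = 0; i < n_tokens; ++i) {
        for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
            data[i*n_embd_v_gqa + j] = j*kv_size + idxs[i];
        }
    }

    // Check: in the transposed cache, component j of token i lives in row j
    // (stride kv_size) at column idxs[i].
    for (uint32_t i = 0; i < n_tokens; ++i) {
        for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
            const int64_t d = data[i*n_embd_v_gqa + j];
            assert(d / kv_size == (int64_t) j); // row    = component index
            assert(d % kv_size == idxs[i]);     // column = cell index
        }
    }
    return 0;
}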
