Commit 40f8c48

kv-cache : prepare K/V buffers for separation
ggml-ci
1 parent 7b63a71 commit 40f8c48

File tree

3 files changed: 88 additions & 31 deletions

src/llama-hparams.cpp

Lines changed: 22 additions & 0 deletions
@@ -65,6 +65,28 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
+bool llama_hparams::is_n_embd_k_gqa_homogeneous() const {
+    uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_k_gqa(il)) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+bool llama_hparams::is_n_embd_v_gqa_homogeneous() const {
+    uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_v_gqa(il)) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
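The two new helpers simply compare every layer's K/V GQA width against the layer-0 value. As a rough standalone sketch of the same check (plain C++ with an invented per-layer table instead of the real llama_hparams accessors; all names and numbers below are made up for illustration):

#include <cstdint>
#include <vector>

// Hypothetical stand-in for hparams: per-layer K-embedding widths.
// In llama.cpp these come from n_embd_k_gqa(il); the values here are invented.
static bool is_homogeneous(const std::vector<uint32_t> & n_embd_k_gqa_per_layer) {
    if (n_embd_k_gqa_per_layer.empty()) {
        return true;
    }
    const uint32_t val = n_embd_k_gqa_per_layer[0]; // reference value, like n_embd_k_gqa()
    for (uint32_t v : n_embd_k_gqa_per_layer) {
        if (v != val) {
            return false; // at least one layer differs -> not homogeneous
        }
    }
    return true;
}

int main() {
    std::vector<uint32_t> uniform = {1024, 1024, 1024}; // all layers share the same KV width
    std::vector<uint32_t> mixed   = {1024, 1024, 512};  // one layer has a smaller KV width
    return is_homogeneous(uniform) && !is_homogeneous(mixed) ? 0 : 1;
}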

src/llama-hparams.h

Lines changed: 4 additions & 0 deletions
@@ -189,6 +189,10 @@ struct llama_hparams {
     // dimension of value embeddings across all k-v heads
     uint32_t n_embd_v_gqa(uint32_t il = 0) const;
 
+    // true if all layers have the same n_embd_k_gqa/n_embd_v_gqa
+    bool is_n_embd_k_gqa_homogeneous() const;
+    bool is_n_embd_v_gqa_homogeneous() const;
+
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
     uint32_t n_embd_r() const;

src/llama-kv-cache-unified.cpp

Lines changed: 62 additions & 31 deletions
@@ -68,6 +68,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
+    if (supports_set_rows) {
+        // TODO: this requirement can be relaxed, but it would be much easier to implement when we have an actual
+        // model that needs this
+        // ref: https://github.com/ggml-org/llama.cpp/pull/14517
+        GGML_ASSERT(hparams.is_n_embd_v_gqa_homogeneous());
+    }
+
     for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
@@ -98,8 +105,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, 1);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, 1);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
@@ -780,33 +787,40 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
 
     auto * k = layers[ikv].k;
 
-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const uint64_t kv_size = get_size();
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, 1,
             ggml_row_size(k->type, hparams.n_embd_head_k),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size),
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const uint64_t kv_size = get_size();
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
+                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)),         // v->nb[2]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size), // v->nb[3]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            0);
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
+            ggml_row_size(v->type, kv_size),                          // v->nb[2]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il))*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
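The switch from ggml_view_3d to ggml_view_4d does not change which data the view covers: the fourth dimension is 1 and the final argument multiplies the would-be nb[3] stride by 0, so the offset is still 0. A minimal sketch of the same pattern outside the cache code, assuming invented sizes (n_embd_head_k, n_head_kv, kv_size, n_kv are made-up numbers here, not values from any real model):

#include "ggml.h"
#include <cstdio>

int main() {
    // made-up sizes for illustration only
    const int64_t n_embd_head_k = 64;
    const int64_t n_head_kv     = 8;
    const int64_t n_embd_k_gqa  = n_embd_head_k*n_head_kv;
    const int64_t kv_size       = 1024;
    const int64_t n_kv          = 256;

    struct ggml_init_params params = {
        /*.mem_size   =*/ 64u*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // K buffer as the cache now allocates it: [n_embd_k_gqa, kv_size, 1]
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd_k_gqa, kv_size, 1);

    // 4D view equivalent to the old 3D view: ne[3] == 1 and the offset evaluates to 0
    struct ggml_tensor * kv = ggml_view_4d(ctx, k,
            n_embd_head_k, n_head_kv, n_kv, 1,
            ggml_row_size(k->type, n_embd_head_k),
            ggml_row_size(k->type, n_embd_k_gqa),
            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*0); // stride written out, offset still 0

    printf("view: %lld x %lld x %lld x %lld\n",
            (long long) kv->ne[0], (long long) kv->ne[1], (long long) kv->ne[2], (long long) kv->ne[3]);

    ggml_free(ctx);
    return 0;
}

Writing the nb[3] stride out explicitly, rather than hard-coding 0, presumably keeps the view code unchanged if the third dimension of the cache buffer later grows beyond 1, which is what the commit title hints at.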
@@ -820,6 +834,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
         return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
@@ -845,24 +863,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
-
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: we can be more explicit here at the cost of extra cont
-        // however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        //v_cur = ggml_transpose(ctx, v_cur);
-        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1, n_kv,     n_embd_v_gqa]
-        // v_cur  [1, n_tokens, n_embd_v_gqa]
-        // v_idxs [n_tokens, 1, 1]
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
@@ -899,7 +911,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+    }
 
     ggml_set_input(v_idxs);
 
@@ -916,7 +934,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         data[i] = sinfo.idxs.at(i);
     }
 }
@@ -931,8 +949,21 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    if (!v_trans) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            data[i] = sinfo.idxs.at(i);
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                data[i*n_embd_v_gqa + j] = j*kv_size + sinfo.idxs.at(i);
+            }
+        }
     }
 }
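For the transposed case, set_input_v_idxs therefore expands each token's destination cell into n_embd_v_gqa flat indices, matching the larger v_idxs tensor built above. A worked example with made-up numbers standing in for get_size(), hparams.n_embd_v_gqa() and sinfo.idxs:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // made-up values: kv_size ~ get_size(), n_embd_v_gqa ~ hparams.n_embd_v_gqa(), idxs ~ sinfo.idxs
    const int64_t kv_size      = 8;
    const int64_t n_embd_v_gqa = 4;
    const std::vector<int64_t> idxs = {2, 5}; // destination cells of the two tokens in the ubatch

    const size_t n_tokens = idxs.size();
    std::vector<int64_t> data(n_tokens*n_embd_v_gqa);

    for (size_t i = 0; i < n_tokens; ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            data[i*n_embd_v_gqa + j] = j*kv_size + idxs[i];
        }
    }

    // token 0 (cell 2) -> 2 10 18 26 ; token 1 (cell 5) -> 5 13 21 29
    for (size_t i = 0; i < data.size(); ++i) {
        printf("%lld%s", (long long) data[i], i + 1 < data.size() ? " " : "\n");
    }

    return 0;
}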
