
Commit b123d89

kv-cache : prepare K/V buffers for separation
ggml-ci
1 parent 67d1ef2 commit b123d89

3 files changed: +127 -36 lines changed

src/llama-hparams.cpp

Lines changed: 40 additions & 0 deletions
@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
+bool llama_hparams::is_n_embd_k_gqa_variable() const {
+    const uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_k_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+bool llama_hparams::is_n_embd_v_gqa_variable() const {
+    const uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_v_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+uint32_t llama_hparams::n_embd_k_gqa_max() const {
+    uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_k_gqa(il));
+    }
+
+    return val;
+}
+
+uint32_t llama_hparams::n_embd_v_gqa_max() const {
+    uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_v_gqa(il));
+    }
+
+    return val;
+}
+
 uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
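The four new helpers share one pattern: take the layer-0 value as the baseline, then either scan the layers for a mismatch or fold in the per-layer maximum. A minimal standalone sketch of that logic, using a made-up hparams-like struct with hard-coded per-layer K widths (illustration only, not the real llama_hparams API):

// Hypothetical stand-in for llama_hparams, for illustration only.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct toy_hparams {
    std::vector<uint32_t> n_embd_k_gqa_per_layer; // per-layer K width (made up)

    uint32_t n_embd_k_gqa(uint32_t il = 0) const { return n_embd_k_gqa_per_layer[il]; }
    uint32_t n_layer() const { return (uint32_t) n_embd_k_gqa_per_layer.size(); }

    // same shape as llama_hparams::is_n_embd_k_gqa_variable()
    bool is_n_embd_k_gqa_variable() const {
        const uint32_t val = n_embd_k_gqa();
        for (uint32_t il = 0; il < n_layer(); ++il) {
            if (val != n_embd_k_gqa(il)) {
                return true;
            }
        }
        return false;
    }

    // same shape as llama_hparams::n_embd_k_gqa_max()
    uint32_t n_embd_k_gqa_max() const {
        uint32_t val = n_embd_k_gqa();
        for (uint32_t il = 0; il < n_layer(); ++il) {
            val = std::max(val, n_embd_k_gqa(il));
        }
        return val;
    }
};

int main() {
    toy_hparams hp = { { 512, 512, 1024, 512 } }; // layer 2 is wider
    std::printf("variable: %d, max: %u\n", hp.is_n_embd_k_gqa_variable(), hp.n_embd_k_gqa_max());
    // prints: variable: 1, max: 1024
}

With all layers equal, is_n_embd_k_gqa_variable() returns false and the max equals the layer-0 value.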

src/llama-hparams.h

Lines changed: 8 additions & 0 deletions
@@ -189,6 +189,14 @@ struct llama_hparams {
     // dimension of value embeddings across all k-v heads
     uint32_t n_embd_v_gqa(uint32_t il = 0) const;
 
+    // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
+    bool is_n_embd_k_gqa_variable() const;
+    bool is_n_embd_v_gqa_variable() const;
+
+    // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
+    uint32_t n_embd_k_gqa_max() const;
+    uint32_t n_embd_v_gqa_max() const;
+
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
     uint32_t n_embd_r() const;

src/llama-kv-cache-unified.cpp

Lines changed: 79 additions & 36 deletions
@@ -68,14 +68,21 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
+    // [TAG_V_CACHE_VARIABLE]
+    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+                __func__, hparams.n_embd_v_gqa_max());
+    }
+
     for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+        // [TAG_V_CACHE_VARIABLE]
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
 
         const char * dev_name = "CPU";

@@ -98,8 +105,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, 1);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, 1);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
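The warning above describes a space trade-off: when the cache is transposed (FA disabled) and the per-layer V widths differ, every V buffer is allocated at hparams.n_embd_v_gqa_max() instead of its own layer width. A back-of-the-envelope sketch of that cost with made-up numbers (two layers, F16 cache, hypothetical widths):

// Rough cache-size arithmetic for the padded, transposed V cache.
// Hypothetical numbers, not taken from any real model.
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t kv_size          = 4096;           // cells in the cache
    const uint64_t bytes_per_elem   = 2;              // e.g. F16 V cache
    const uint64_t n_embd_v_gqa[2]  = { 512, 1024 };  // per-layer V widths (layer 1 wider)

    // without padding: each layer allocates exactly its own width
    uint64_t bytes_exact = 0;
    for (int il = 0; il < 2; ++il) {
        bytes_exact += n_embd_v_gqa[il] * kv_size * bytes_per_elem;
    }

    // with v_trans and variable widths: every layer is padded to the max width
    const uint64_t n_embd_v_gqa_max = 1024;
    const uint64_t bytes_padded     = 2 * n_embd_v_gqa_max * kv_size * bytes_per_elem;

    std::printf("exact: %llu MiB, padded: %llu MiB\n",
            (unsigned long long) (bytes_exact  >> 20),
            (unsigned long long) (bytes_padded >> 20));
    // exact: 12 MiB, padded: 16 MiB
}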
@@ -785,33 +792,47 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
 
     auto * k = layers[ikv].k;
 
-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_k_gqa = k->ne[0];
+
+    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, 1,
             ggml_row_size(k->type, hparams.n_embd_head_k),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            ggml_row_size(k->type, n_embd_k_gqa),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_v_gqa = v->ne[0];
+
+    // [TAG_V_CACHE_VARIABLE]
+    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
+                ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),          // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size),  // v->nb[3]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            0);
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, kv_size),                       // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa),          // v->nb[3]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
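For a non-quantized cache type, a ggml view addresses element (i0, i1, i2, i3) at byte offset i0*nb[0] + i1*nb[1] + i2*nb[2] + i3*nb[3]. Plugging in the strides of the transposed ggml_view_4d above, channel j = head*n_embd_head_v + dim of cell c lands at flat element j*kv_size + c, and nb[3] spans one whole buffer (the fourth dimension is 1 for now). A small plain-integer check of that arithmetic, with made-up sizes and no ggml calls:

// Index arithmetic for the transposed V view above: element (cell, head, dim)
// of the view should land at flat element j*kv_size + cell in the buffer,
// where j = head*n_embd_head_v + dim is the V channel.
// Made-up sizes, plain integers, no ggml involved.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t kv_size       = 4096;
    const uint64_t n_embd_head_v = 128;
    const uint64_t n_head_kv     = 8;
    const uint64_t n_embd_v_gqa  = n_embd_head_v*n_head_kv;
    const uint64_t ts            = 2; // bytes per element (e.g. F16)

    // strides of the ggml_view_4d(ctx, v, n_kv, n_head_kv, n_embd_head_v, 1, ...) call
    const uint64_t nb0 = ts;
    const uint64_t nb1 = ts*kv_size*n_embd_head_v;
    const uint64_t nb2 = ts*kv_size;
    const uint64_t nb3 = ts*kv_size*n_embd_v_gqa;
    (void) nb3; // the fourth index is always 0 while there is a single buffer

    for (uint64_t cell = 0; cell < kv_size; cell += 1000) {
        for (uint64_t head = 0; head < n_head_kv; ++head) {
            for (uint64_t dim = 0; dim < n_embd_head_v; dim += 37) {
                const uint64_t off = cell*nb0 + head*nb1 + dim*nb2 + 0*nb3;
                const uint64_t j   = head*n_embd_head_v + dim;
                assert(off == ts*(j*kv_size + cell));
            }
        }
    }
    std::printf("transposed view strides match the j*kv_size + cell layout\n");
}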
@@ -825,6 +846,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
         return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }

@@ -843,31 +868,30 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v->ne[0];
-    const int64_t n_tokens     = v_cur->ne[2];
+    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
+    const int64_t n_tokens     = v_cur->ne[2];
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
-        // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+        // [TAG_V_CACHE_VARIABLE]
+        if (n_embd_v_gqa < v->ne[0]) {
+            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+        }
 
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: we can be more explicit here at the cost of extra cont
-        // however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        //v_cur = ggml_transpose(ctx, v_cur);
-        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1, n_kv,     n_embd_v_gqa]
-        // v_cur  [1, n_tokens, n_embd_v_gqa]
-        // v_idxs [n_tokens, 1, 1]
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
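In the transposed branch, reshaping the cache to single-element rows lets ggml_set_rows place each (token, channel) value at an arbitrary flat position; together with the channel-expanded indices built in set_input_v_idxs (see the last hunk below), channel j of a token assigned to cell c ends up at flat element j*kv_size + c. A plain-C++ emulation of that scatter with toy sizes (not the ggml code path):

// Plain-C++ emulation of the transposed scatter: every (token, channel) element
// becomes its own "row" of length 1 and is written to flat position
// channel*kv_size + cell, which is column `cell` of channel-row `channel`
// in the transposed [kv_size x n_embd] cache. Toy sizes, not the ggml code.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const uint32_t kv_size = 8;
    const uint32_t n_embd  = 4;                             // padded V width of this buffer
    std::vector<float> cache(n_embd*kv_size, 0.0f);

    const uint32_t cell = 5;                                // destination cell for this token
    const std::vector<float> v_token = { 10, 11, 12, 13 };  // the token's V channels

    for (uint32_t j = 0; j < n_embd; ++j) {
        cache[j*kv_size + cell] = v_token[j];               // single-element row scatter
    }

    // reading channel j at cell `cell` through the transposed layout gets it back
    for (uint32_t j = 0; j < n_embd; ++j) {
        assert(cache[j*kv_size + cell] == v_token[j]);
    }
    return 0;
}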

@@ -904,7 +928,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
+    }
 
     ggml_set_input(v_idxs);

@@ -921,7 +951,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         data[i] = sinfo.idxs.at(i);
     }
 }
@@ -936,8 +966,21 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    if (!v_trans) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            data[i] = sinfo.idxs.at(i);
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
+
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                data[i*n_embd_v_gqa + j] = j*kv_size + sinfo.idxs.at(i);
+            }
+        }
     }
 }
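The transposed branch expands each token's destination cell into n_embd_v_gqa_max() consecutive index entries, one per V channel: entry (i, j) is j*kv_size + cell_i. A standalone sketch with hypothetical sizes (kv_size = 8, max V width 4, tokens at cells 5 and 6), not the llama.cpp code:

// Build the channel-expanded index table for the transposed V scatter.
// Toy numbers only; `cells` stands in for sinfo.idxs.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t kv_size      = 8;
    const int64_t n_embd_v_gqa = 4;
    const std::vector<int64_t> cells = { 5, 6 };

    std::vector<int64_t> data(cells.size()*n_embd_v_gqa);
    for (size_t i = 0; i < cells.size(); ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            data[i*n_embd_v_gqa + j] = j*kv_size + cells[i];
        }
    }

    for (int64_t idx : data) {
        std::printf("%lld ", (long long) idx);
    }
    std::printf("\n"); // prints: 5 13 21 29 6 14 22 30
}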
