Skip to content

Commit 1ff97ad

Browse files
committed
fix: Use per-layer sizing everywhere in kv caches
Branch: GraniteFour — Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 8edfe63 commit 1ff97ad

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

src/llama-kv-cache-recurrent.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
6969
continue;
7070
}
7171

72-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
73-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
72+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
73+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
7474

7575
const char * dev_name = "CPU";
7676

@@ -768,7 +768,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
768768
// Iterate and write all the keys first, each row is a cell
769769
// Get whole range at a time
770770
for (uint32_t il = 0; il < n_layer; ++il) {
771-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
771+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
772772

773773
// Write key type
774774
const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -788,7 +788,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
788788

789789
if (!v_trans) {
790790
for (uint32_t il = 0; il < n_layer; ++il) {
791-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
791+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
792792

793793
// Write value type
794794
const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -809,7 +809,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
809809
// When v is transposed, we also need the element size and get the element ranges from each row
810810
const uint32_t kv_size = size;
811811
for (uint32_t il = 0; il < n_layer; ++il) {
812-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
812+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
813813

814814
// Write value type
815815
const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -956,7 +956,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
956956

957957
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
958958
for (uint32_t il = 0; il < n_layer; ++il) {
959-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
959+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
960960

961961
// Read type of key
962962
int32_t k_type_i_ref;
@@ -984,7 +984,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
984984

985985
if (!v_trans) {
986986
for (uint32_t il = 0; il < n_layer; ++il) {
987-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
987+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
988988

989989
// Read type of value
990990
int32_t v_type_i_ref;
@@ -1012,7 +1012,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
10121012
} else {
10131013
// For each layer, read the values for each cell (transposed)
10141014
for (uint32_t il = 0; il < n_layer; ++il) {
1015-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1015+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
10161016

10171017
// Read type of value
10181018
int32_t v_type_i_ref;

src/llama-kv-cache-unified.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
6868
continue;
6969
}
7070

71-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
72-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
71+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
72+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
7373

7474
const char * dev_name = "CPU";
7575

@@ -1367,7 +1367,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
13671367
for (const auto & layer : layers) {
13681368
const uint32_t il = layer.il;
13691369

1370-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1370+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
13711371

13721372
// Write key type
13731373
const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1389,7 +1389,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
13891389
for (const auto & layer : layers) {
13901390
const uint32_t il = layer.il;
13911391

1392-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1392+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
13931393

13941394
// Write value type
13951395
const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1413,7 +1413,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
14131413
for (const auto & layer : layers) {
14141414
const uint32_t il = layer.il;
14151415

1416-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1416+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
14171417

14181418
// Write value type
14191419
const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1556,7 +1556,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
15561556
for (const auto & layer : layers) {
15571557
const uint32_t il = layer.il;
15581558

1559-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1559+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
15601560

15611561
// Read type of key
15621562
int32_t k_type_i_ref;
@@ -1586,7 +1586,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
15861586
for (const auto & layer : layers) {
15871587
const uint32_t il = layer.il;
15881588

1589-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1589+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
15901590

15911591
// Read type of value
15921592
int32_t v_type_i_ref;
@@ -1616,7 +1616,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
16161616
for (const auto & layer : layers) {
16171617
const uint32_t il = layer.il;
16181618

1619-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1619+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
16201620

16211621
// Read type of value
16221622
int32_t v_type_i_ref;

0 commit comments

Comments (0)