Skip to content

Commit 17abb2b

Browse files
committed
fix: Use per-layer sizing everywhere in kv caches
Branch: GraniteFour Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent f56561b commit 17abb2b

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

src/llama-kv-cache-recurrent.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
6969
continue;
7070
}
7171

72-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
73-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
72+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
73+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
7474

7575
const char * dev_name = "CPU";
7676

@@ -754,7 +754,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
754754
// Iterate and write all the keys first, each row is a cell
755755
// Get whole range at a time
756756
for (uint32_t il = 0; il < n_layer; ++il) {
757-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
757+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
758758

759759
// Write key type
760760
const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -774,7 +774,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
774774

775775
if (!v_trans) {
776776
for (uint32_t il = 0; il < n_layer; ++il) {
777-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
777+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
778778

779779
// Write value type
780780
const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -795,7 +795,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
795795
// When v is transposed, we also need the element size and get the element ranges from each row
796796
const uint32_t kv_size = size;
797797
for (uint32_t il = 0; il < n_layer; ++il) {
798-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
798+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
799799

800800
// Write value type
801801
const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -942,7 +942,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
942942

943943
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
944944
for (uint32_t il = 0; il < n_layer; ++il) {
945-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
945+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
946946

947947
// Read type of key
948948
int32_t k_type_i_ref;
@@ -970,7 +970,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
970970

971971
if (!v_trans) {
972972
for (uint32_t il = 0; il < n_layer; ++il) {
973-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
973+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
974974

975975
// Read type of value
976976
int32_t v_type_i_ref;
@@ -998,7 +998,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
998998
} else {
999999
// For each layer, read the values for each cell (transposed)
10001000
for (uint32_t il = 0; il < n_layer; ++il) {
1001-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1001+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
10021002

10031003
// Read type of value
10041004
int32_t v_type_i_ref;

src/llama-kv-cache-unified.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
6868
continue;
6969
}
7070

71-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
72-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
71+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
72+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
7373

7474
const char * dev_name = "CPU";
7575

@@ -1430,7 +1430,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
14301430
for (const auto & layer : layers) {
14311431
const uint32_t il = layer.il;
14321432

1433-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1433+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
14341434

14351435
// Write key type
14361436
const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1452,7 +1452,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
14521452
for (const auto & layer : layers) {
14531453
const uint32_t il = layer.il;
14541454

1455-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1455+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
14561456

14571457
// Write value type
14581458
const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1476,7 +1476,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
14761476
for (const auto & layer : layers) {
14771477
const uint32_t il = layer.il;
14781478

1479-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1479+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
14801480

14811481
// Write value type
14821482
const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1621,7 +1621,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
16211621
for (const auto & layer : layers) {
16221622
const uint32_t il = layer.il;
16231623

1624-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1624+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
16251625

16261626
// Read type of key
16271627
int32_t k_type_i_ref;
@@ -1651,7 +1651,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
16511651
for (const auto & layer : layers) {
16521652
const uint32_t il = layer.il;
16531653

1654-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1654+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
16551655

16561656
// Read type of value
16571657
int32_t v_type_i_ref;
@@ -1681,7 +1681,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
16811681
for (const auto & layer : layers) {
16821682
const uint32_t il = layer.il;
16831683

1684-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1684+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
16851685

16861686
// Read type of value
16871687
int32_t v_type_i_ref;

0 commit comments

Comments
 (0)