Skip to content

Commit 5d0f877

Browse files
committed
fix: Use per-layer sizing everywhere in kv caches
Branch: GraniteFour Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent eed7811 commit 5d0f877

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

src/llama-kv-cache-recurrent.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
6969
continue;
7070
}
7171

72-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
73-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
72+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
73+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
7474

7575
const char * dev_name = "CPU";
7676

@@ -756,7 +756,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
756756
// Iterate and write all the keys first, each row is a cell
757757
// Get whole range at a time
758758
for (uint32_t il = 0; il < n_layer; ++il) {
759-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
759+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
760760

761761
// Write key type
762762
const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -776,7 +776,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
776776

777777
if (!v_trans) {
778778
for (uint32_t il = 0; il < n_layer; ++il) {
779-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
779+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
780780

781781
// Write value type
782782
const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -797,7 +797,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
797797
// When v is transposed, we also need the element size and get the element ranges from each row
798798
const uint32_t kv_size = size;
799799
for (uint32_t il = 0; il < n_layer; ++il) {
800-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
800+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
801801

802802
// Write value type
803803
const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -944,7 +944,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
944944

945945
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
946946
for (uint32_t il = 0; il < n_layer; ++il) {
947-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
947+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
948948

949949
// Read type of key
950950
int32_t k_type_i_ref;
@@ -972,7 +972,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
972972

973973
if (!v_trans) {
974974
for (uint32_t il = 0; il < n_layer; ++il) {
975-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
975+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
976976

977977
// Read type of value
978978
int32_t v_type_i_ref;
@@ -1000,7 +1000,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
10001000
} else {
10011001
// For each layer, read the values for each cell (transposed)
10021002
for (uint32_t il = 0; il < n_layer; ++il) {
1003-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1003+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
10041004

10051005
// Read type of value
10061006
int32_t v_type_i_ref;

src/llama-kv-cache-unified.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
6868
continue;
6969
}
7070

71-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
72-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
71+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
72+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
7373

7474
const char * dev_name = "CPU";
7575

@@ -1365,7 +1365,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
13651365
for (const auto & layer : layers) {
13661366
const uint32_t il = layer.il;
13671367

1368-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1368+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
13691369

13701370
// Write key type
13711371
const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1387,7 +1387,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
13871387
for (const auto & layer : layers) {
13881388
const uint32_t il = layer.il;
13891389

1390-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1390+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
13911391

13921392
// Write value type
13931393
const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1411,7 +1411,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
14111411
for (const auto & layer : layers) {
14121412
const uint32_t il = layer.il;
14131413

1414-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1414+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
14151415

14161416
// Write value type
14171417
const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1554,7 +1554,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
15541554
for (const auto & layer : layers) {
15551555
const uint32_t il = layer.il;
15561556

1557-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1557+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
15581558

15591559
// Read type of key
15601560
int32_t k_type_i_ref;
@@ -1584,7 +1584,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
15841584
for (const auto & layer : layers) {
15851585
const uint32_t il = layer.il;
15861586

1587-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1587+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
15881588

15891589
// Read type of value
15901590
int32_t v_type_i_ref;
@@ -1614,7 +1614,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
16141614
for (const auto & layer : layers) {
16151615
const uint32_t il = layer.il;
16161616

1617-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1617+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
16181618

16191619
// Read type of value
16201620
int32_t v_type_i_ref;

0 commit comments

Comments
 (0)