Skip to content

Commit 4dc25e6

Browse files
committed
fix: Use per-layer sizing everywhere in kv caches

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

1 parent 4db94e9 · commit 4dc25e6

File tree

2 files changed: +16 −16 lines changed

src/llama-kv-cache-recurrent.cpp

Lines changed: 8 additions & 8 deletions
@@ -68,8 +68,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);

         const char * dev_name = "CPU";

@@ -771,7 +771,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -791,7 +791,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std

     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -812,7 +812,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -959,7 +959,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Read type of key
         int32_t k_type_i_ref;
@@ -987,7 +987,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Read type of value
             int32_t v_type_i_ref;
@@ -1015,7 +1015,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Read type of value
             int32_t v_type_i_ref;

src/llama-kv-cache-unified.cpp

Lines changed: 8 additions & 8 deletions
@@ -67,8 +67,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         const char * dev_name = "CPU";

@@ -1324,7 +1324,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1346,7 +1346,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;

-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Write value type
             const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1370,7 +1370,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;

-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Write value type
             const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1513,7 +1513,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Read type of key
         int32_t k_type_i_ref;
@@ -1543,7 +1543,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;

-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Read type of value
             int32_t v_type_i_ref;
@@ -1573,7 +1573,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;

-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Read type of value
             int32_t v_type_i_ref;

Comments (0)