
Commit 9aa941d

fix: Use per-layer sizing everywhere in kv caches

Branch: GraniteFour
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

1 parent 3aa78f7

1 file changed: +16 -16 lines changed

src/llama-kv-cache.cpp (16 additions, 16 deletions)
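The change itself is mechanical: every call site now passes the layer index into hparams.n_embd_k_s() / hparams.n_embd_v_s(), so the recurrent-state contribution to the K/V row size can vary per layer. In a hybrid checkpoint such as GraniteFour, only the recurrent layers carry SSM state, and pure-attention layers should contribute zero. The commit does not show the helper definitions themselves; the following is a minimal sketch of what per-layer overloads could look like, assuming a recurrent_layer(il) predicate and Mamba-style ssm_* hyperparameters (neither is confirmed by this diff):

// Hypothetical sketch: the (il) signatures match the call sites in the
// diff below, but the bodies and the recurrent_layer(il) predicate are
// assumptions for illustration, not part of this commit.
uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
    if (!recurrent_layer(il)) {
        return 0; // attention-only layers have no recurrent K state
    }
    // Mamba-style conv state: (d_conv - 1) columns of d_inner values
    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
}

uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
    if (!recurrent_layer(il)) {
        return 0; // attention-only layers have no recurrent V state
    }
    // Mamba-style SSM state: d_state values per d_inner channel
    return ssm_d_state * ssm_d_inner;
}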
@@ -75,8 +75,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         const char * dev_name = "CPU";
 
@@ -1369,7 +1369,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1391,7 +1391,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1415,7 +1415,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1552,7 +1552,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -1582,7 +1582,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -1612,7 +1612,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -1921,8 +1921,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
 
         const char * dev_name = "CPU";
 
@@ -2649,7 +2649,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -2669,7 +2669,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -2690,7 +2690,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -2837,7 +2837,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -2865,7 +2865,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -2893,7 +2893,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Read type of value
             int32_t v_type_i_ref;
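Why per-layer sizing matters for these (de)serializers: each loop writes or reads one row per cell whose width is derived from n_embd_k_gqa / n_embd_v_gqa, so if a single global n_embd_k_s() were folded into every layer's row size, attention-only layers in a hybrid model would be serialized with inflated rows and the stream would no longer round-trip. A self-contained toy illustration of the per-layer row-width computation (made-up sizes, not the llama.cpp API):

#include <cstdint>
#include <cstdio>

// Toy stand-in for llama_hparams: a 4-layer hybrid model where only the
// odd layers are recurrent. All sizes are made up for illustration.
struct toy_hparams {
    uint32_t n_embd_k_gqa(uint32_t /*il*/) const { return 1024; }
    uint32_t n_embd_k_s  (uint32_t il)     const { return il % 2 == 1 ? 1536 : 0; }
};

int main() {
    toy_hparams hparams;
    for (uint32_t il = 0; il < 4; ++il) {
        // per-layer row width, as computed in the patched loops above
        const uint32_t n_embd_k = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
        std::printf("layer %u: k row = %u elements\n", il, n_embd_k);
    }
    return 0;
}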
