
Commit 6b58853 (parent: 73f8984)
Branch: GraniteFour

fix: Use per-layer sizing everywhere in kv caches

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
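For context: the new il argument lets a hybrid architecture (the GraniteFour branch mixes attention layers with recurrent/SSM layers) report a per-layer recurrent-state size, so attention layers contribute zero SSM state instead of one global value. Below is a minimal sketch of what such per-layer accessors could look like, assuming a hypothetical is_recurrent(il) predicate on llama_hparams and upstream's ssm_d_conv / ssm_d_inner / ssm_d_state fields; the actual signatures in the branch may differ.

    // Hypothetical sketch, not the branch's actual implementation.
    // is_recurrent(il) is an assumed per-layer predicate.
    uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
        if (!is_recurrent(il)) {
            return 0; // attention layers carry no conv state
        }
        // rows of Mamba-style convolution state for this layer
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
    }

    uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
        if (!is_recurrent(il)) {
            return 0; // attention layers carry no SSM recurrence state
        }
        // size of the SSM recurrence state for this layer
        return ssm_d_state * ssm_d_inner;
    }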

File tree: 1 file changed (+16, -16)

src/llama-kv-cache.cpp (16 additions, 16 deletions)
@@ -205,8 +205,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         const char * dev_name = "CPU";
 
@@ -1447,7 +1447,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1469,7 +1469,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1493,7 +1493,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1636,7 +1636,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -1666,7 +1666,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -1696,7 +1696,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -2206,8 +2206,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
 
         const char * dev_name = "CPU";
 
@@ -2909,7 +2909,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -2929,7 +2929,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -2950,7 +2950,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -3097,7 +3097,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -3125,7 +3125,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -3153,7 +3153,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Read type of value
             int32_t v_type_i_ref;
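Why this matters, with illustrative numbers (hypothetical, not taken from any real model): an attention layer might have n_embd_k_gqa(il) = 1024 and no SSM state, while a recurrent layer with ssm_d_conv = 4 and ssm_d_inner = 8192 needs an extra (4 - 1) * 8192 = 24576 elements per K row. A global n_embd_k_s() applies that extra to every layer, so the allocator oversizes attention layers, and the row sizes written by state_write_data can disagree with what state_read_data expects. Passing il keeps allocation and (de)serialization consistent layer by layer.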
