@@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);

        const char * dev_name = "CPU";
@@ -768,7 +768,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

        // Write key type
        const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -788,7 +788,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std

    if (!v_trans) {
        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

            // Write value type
            const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -809,7 +809,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
        // When v is transposed, we also need the element size and get the element ranges from each row
        const uint32_t kv_size = size;
        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

            // Write value type
            const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -956,7 +956,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

        // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

            // Read type of key
            int32_t k_type_i_ref;
@@ -984,7 +984,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

        if (!v_trans) {
            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

                // Read type of value
                int32_t v_type_i_ref;
@@ -1012,7 +1012,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
        } else {
            // For each layer, read the values for each cell (transposed)
            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

                // Read type of value
                int32_t v_type_i_ref;
0 commit comments