@@ -68,8 +68,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
68
68
continue ;
69
69
}
70
70
71
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (i) + hparams.n_embd_k_s ();
72
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (i) + hparams.n_embd_v_s ();
71
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (i) + hparams.n_embd_k_s (i );
72
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (i) + hparams.n_embd_v_s (i );
73
73
74
74
const char * dev_name = " CPU" ;
75
75
@@ -771,7 +771,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
771
771
// Iterate and write all the keys first, each row is a cell
772
772
// Get whole range at a time
773
773
for (uint32_t il = 0 ; il < n_layer; ++il) {
774
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
774
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
775
775
776
776
// Write key type
777
777
const int32_t k_type_i = (int32_t )k_l[il]->type ;
@@ -791,7 +791,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
791
791
792
792
if (!v_trans) {
793
793
for (uint32_t il = 0 ; il < n_layer; ++il) {
794
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
794
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
795
795
796
796
// Write value type
797
797
const int32_t v_type_i = (int32_t )v_l[il]->type ;
@@ -812,7 +812,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
812
812
// When v is transposed, we also need the element size and get the element ranges from each row
813
813
const uint32_t kv_size = size;
814
814
for (uint32_t il = 0 ; il < n_layer; ++il) {
815
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
815
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
816
816
817
817
// Write value type
818
818
const int32_t v_type_i = (int32_t )v_l[il]->type ;
@@ -959,7 +959,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
959
959
960
960
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
961
961
for (uint32_t il = 0 ; il < n_layer; ++il) {
962
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
962
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
963
963
964
964
// Read type of key
965
965
int32_t k_type_i_ref;
@@ -987,7 +987,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
987
987
988
988
if (!v_trans) {
989
989
for (uint32_t il = 0 ; il < n_layer; ++il) {
990
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
990
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
991
991
992
992
// Read type of value
993
993
int32_t v_type_i_ref;
@@ -1015,7 +1015,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
1015
1015
} else {
1016
1016
// For each layer, read the values for each cell (transposed)
1017
1017
for (uint32_t il = 0 ; il < n_layer; ++il) {
1018
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1018
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
1019
1019
1020
1020
// Read type of value
1021
1021
int32_t v_type_i_ref;
0 commit comments