@@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
 
         const char * dev_name = "CPU";
 
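The hunk above threads the layer index into the recurrent-state terms, so each layer can report its own K/V state width when the cache tensors are sized. A minimal standalone sketch, not the actual llama_hparams implementation (the member names and sizes below are assumptions), of what index-taking accessors could look like:

// Hypothetical sketch only -- not the llama.cpp llama_hparams definition.
// It illustrates why n_embd_k_s()/n_embd_v_s() would take a layer index: in a
// hybrid model only some layers are recurrent, so the state width can differ
// per layer instead of being a single model-wide constant.
#include <array>
#include <cstdint>
#include <cstdio>

struct hparams_sketch {
    std::array<uint32_t, 4> k_state_per_layer; // assumed per-layer K-state widths
    std::array<uint32_t, 4> v_state_per_layer; // assumed per-layer V-state widths

    uint32_t n_embd_k_s(uint32_t il) const { return k_state_per_layer[il]; }
    uint32_t n_embd_v_s(uint32_t il) const { return v_state_per_layer[il]; }
};

int main() {
    // layers 1 and 3 are "recurrent" in this toy configuration
    const hparams_sketch hp = { {0, 128, 0, 128}, {0, 4096, 0, 4096} };
    for (uint32_t il = 0; il < 4; ++il) {
        std::printf("layer %u: k_s=%u v_s=%u\n", il, hp.n_embd_k_s(il), hp.n_embd_v_s(il));
    }
}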
@@ -754,7 +754,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -774,7 +774,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -795,7 +795,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
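In state_write_data the practical effect of the three hunks above is that each layer's serialized row width now includes that layer's own state size rather than a single global one. A standalone sketch of that sizing arithmetic, under assumed per-layer widths and f32 cells (none of the names or numbers below are quoted from the file):

// Standalone sketch (not llama.cpp code): each layer writes cell_count rows of
// (n_embd_v_gqa(il) + n_embd_v_s(il)) elements, so layers with different
// recurrent state sizes contribute different row byte counts to the stream.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t cell_count = 4;
    // assumed per-layer widths: attention part + recurrent-state part
    const std::vector<uint32_t> n_embd_v_gqa = {256, 256, 0};
    const std::vector<uint32_t> n_embd_v_s   = {0, 0, 4096};

    size_t total = 0;
    for (size_t il = 0; il < n_embd_v_gqa.size(); ++il) {
        const uint32_t row = n_embd_v_gqa[il] + n_embd_v_s[il];  // per-layer row width
        total += (size_t) cell_count * row * sizeof(float);      // assuming f32 storage
    }
    std::printf("bytes to write: %zu\n", total);
}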
@@ -942,7 +942,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
         // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
             // Read type of key
             int32_t k_type_i_ref;
@@ -970,7 +970,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
         if (!v_trans) {
             for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
                 // Read type of value
                 int32_t v_type_i_ref;
@@ -998,7 +998,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
         } else {
            // For each layer, read the values for each cell (transposed)
            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
                 // Read type of value
                 int32_t v_type_i_ref;
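On the read side, state_read_data recomputes the same per-layer widths so the restored rows line up with what was written; if the two sides disagreed, the cache would be restored with misaligned rows. A hedged sketch of that kind of width check, with hypothetical helper names and values that are not taken from the file's actual read logic:

// Standalone sketch (not llama.cpp code): the reader recomputes the per-layer
// row width and should reject a stream whose stored width differs, e.g. one
// written by a build that still used the layer-independent n_embd_v_s().
#include <cstdint>
#include <cstdio>

// hypothetical stand-in for hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il)
static uint32_t expected_row_width(uint32_t il) { return il == 2 ? 4096 : 256; }

static bool check_row_width(uint32_t il, uint32_t stored_width) {
    const uint32_t expected = expected_row_width(il);
    if (stored_width != expected) {
        std::fprintf(stderr, "layer %u: stored row width %u != expected %u\n",
                     il, stored_width, expected);
        return false;
    }
    return true;
}

int main() {
    // a stream sized with a single global state width would fail for layer 2
    std::printf("layer 2, stored 256:  %s\n", check_row_width(2, 256)  ? "ok" : "mismatch");
    std::printf("layer 2, stored 4096: %s\n", check_row_width(2, 4096) ? "ok" : "mismatch");
}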