@@ -68,14 +68,21 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
+    // [TAG_V_CACHE_VARIABLE]
+    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+                __func__, hparams.n_embd_v_gqa_max());
+    }
+
     for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+        // [TAG_V_CACHE_VARIABLE]
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
 
         const char * dev_name = "CPU";
 
@@ -98,8 +105,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, 1);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, 1);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
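
For intuition about the two hunks above (illustration only, not part of the patch): when flash attention is disabled the V cache is stored transposed, so if layers have different V embedding widths they are all allocated at the maximum width and the narrower layers are padded up to it. A minimal standalone C++ sketch with made-up layer widths:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical per-layer V widths, standing in for hparams.n_embd_v_gqa(il)
    const std::vector<uint32_t> n_embd_v_gqa = { 1024, 1024, 512, 512 };
    const bool v_trans = true; // FA disabled -> V cache stored transposed

    // stand-in for hparams.n_embd_v_gqa_max()
    const uint32_t n_embd_v_gqa_max = *std::max_element(n_embd_v_gqa.begin(), n_embd_v_gqa.end());

    for (size_t il = 0; il < n_embd_v_gqa.size(); ++il) {
        // same selection as in the constructor: pad to the max only when transposed
        const uint32_t row = !v_trans ? n_embd_v_gqa[il] : n_embd_v_gqa_max;
        printf("layer %zu: V cache row width = %u (real width %u)\n", il, row, n_embd_v_gqa[il]);
    }
    return 0;
}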
@@ -785,33 +792,47 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
 
     auto * k = layers[ikv].k;
 
-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_k_gqa = k->ne[0];
+
+    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, 1,
             ggml_row_size(k->type, hparams.n_embd_head_k),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            ggml_row_size(k->type, n_embd_k_gqa),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_v_gqa = v->ne[0];
+
+    // [TAG_V_CACHE_VARIABLE]
+    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                hparams.n_embd_head_v == 0 ? 0 : ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)),                               // v->nb[2]
-                0);
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
+                ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),          // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size),  // v->nb[3]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            0);
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, kv_size),                       // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa),          // v->nb[3]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
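
The strides passed to ggml_view_4d() above can be sanity-checked with plain arithmetic. Below is a small standalone sketch (not part of the patch; it assumes a 2-byte element type and made-up dimensions) showing that the 4-D K view addresses exactly the same bytes as the flat [n_embd_k_gqa, kv_size] layout of the cache tensor:

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical dimensions; in the patch they come from hparams and get_size()
    const uint64_t n_embd_head_k = 64;
    const uint64_t n_head_kv     = 4;
    const uint64_t n_embd_k_gqa  = n_embd_head_k*n_head_kv;
    const uint64_t kv_size       = 8192;
    const uint64_t esize         = 2; // bytes per element, e.g. F16

    // strides of the [n_embd_head_k, n_head_kv, n_kv, 1] view built in get_k(); the view offset is 0
    const uint64_t nb1 = esize*n_embd_head_k;
    const uint64_t nb2 = esize*n_embd_k_gqa;
    const uint64_t nb3 = esize*n_embd_k_gqa*kv_size; // stride of the size-1 outermost dim

    // arbitrary element: component c of head h at cache position p
    const uint64_t c = 10, h = 3, p = 1234;

    const uint64_t off_view = c*esize + h*nb1 + p*nb2 + 0*nb3;
    const uint64_t off_flat = (p*n_embd_k_gqa + h*n_embd_head_k + c)*esize;

    assert(off_view == off_flat);
    printf("offsets match: %llu bytes\n", (unsigned long long) off_view);
    return 0;
}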
@@ -825,6 +846,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
         return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
@@ -843,31 +868,30 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v->ne[0];
-    const int64_t n_tokens     = v_cur->ne[2];
+    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
+    const int64_t n_tokens     = v_cur->ne[2];
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
-        // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+        // [TAG_V_CACHE_VARIABLE]
+        if (n_embd_v_gqa < v->ne[0]) {
+            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+        }
 
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: we can be more explicit here at the cost of extra cont
-        // however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        // v_cur = ggml_transpose(ctx, v_cur);
-        // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1, n_kv,     n_embd_v_gqa]
-        // v_cur  [1, n_tokens, n_embd_v_gqa]
-        // v_idxs [n_tokens, 1, 1]
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
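
A note on the ggml_pad() call above (illustration only, not part of the patch): when this layer's true V width n_embd_v_gqa is smaller than the padded cache row v->ne[0], the incoming values are extended with zeros so every token contributes exactly one full cache row before the element-wise scatter. A minimal sketch with plain arrays and made-up sizes:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical sizes: this layer's real V width vs. the cache row width padded to the max
    const int64_t n_embd_v_gqa     = 3; // corresponds to v_cur->ne[0]*v_cur->ne[1]
    const int64_t n_embd_v_gqa_max = 5; // corresponds to v->ne[0]
    const int64_t n_tokens         = 2;

    // incoming values: one row of n_embd_v_gqa floats per token
    std::vector<float> v_cur(n_tokens*n_embd_v_gqa, 1.0f);

    // what the zero padding amounts to: each row grows to the cache row width, new entries stay 0
    std::vector<float> v_padded(n_tokens*n_embd_v_gqa_max, 0.0f);
    for (int64_t t = 0; t < n_tokens; ++t) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            v_padded[t*n_embd_v_gqa_max + j] = v_cur[t*n_embd_v_gqa + j];
        }
    }

    printf("token 0 padded row:");
    for (int64_t j = 0; j < n_embd_v_gqa_max; ++j) {
        printf(" %.0f", v_padded[j]);
    }
    printf("\n"); // prints: token 0 padded row: 1 1 1 0 0
    return 0;
}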
@@ -904,7 +928,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
+    }
 
     ggml_set_input(v_idxs);
 
@@ -921,7 +951,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         data[i] = sinfo.idxs.at(i);
     }
 }
@@ -936,8 +966,21 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    if (!v_trans) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            data[i] = sinfo.idxs.at(i);
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
+
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                data[i*n_embd_v_gqa + j] = j*kv_size + sinfo.idxs.at(i);
+            }
+        }
     }
 }
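
For reference, a small standalone sketch of what the transposed-index computation above produces (slot values and sizes below are made up, not from the patch): element j of token i is written at flat position j*kv_size + slot, i.e. position slot within row j of a buffer that stores n_embd_v_gqa rows of kv_size elements each, which is the component-major layout the transposed V cache uses.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical sizes; in the patch kv_size = get_size() and n_embd_v_gqa = hparams.n_embd_v_gqa_max()
    const int64_t kv_size      = 16;
    const int64_t n_embd_v_gqa = 4;

    // hypothetical cache slots assigned to 3 tokens (the role of sinfo.idxs)
    const std::vector<int64_t> slots = { 7, 8, 12 };
    const int64_t n_tokens = (int64_t) slots.size();

    // build the scatter indices the same way set_input_v_idxs() does for the transposed cache
    std::vector<int64_t> data(n_tokens*n_embd_v_gqa);
    for (int64_t i = 0; i < n_tokens; ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            data[i*n_embd_v_gqa + j] = j*kv_size + slots[i];
        }
    }

    // each index addresses column slots[i] of row j in the row-major buffer of
    // n_embd_v_gqa rows x kv_size columns that backs the transposed V cache
    for (int64_t i = 0; i < n_tokens; ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            const int64_t idx = data[i*n_embd_v_gqa + j];
            assert(idx / kv_size == j && idx % kv_size == slots[i]);
        }
    }

    printf("all %lld indices map to (row = V component, column = slot)\n",
            (long long) (n_tokens*n_embd_v_gqa));
    return 0;
}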