@@ -68,6 +68,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
+    if (supports_set_rows) {
+        // TODO: this requirement can be relaxed, but it would be much easier to implement when we have an actual
+        //       model that needs this
+        // ref: https://github.com/ggml-org/llama.cpp/pull/14517
+        GGML_ASSERT(hparams.is_n_embd_v_gqa_homogeneous());
+    }
+
     for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
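A note on the new assert: the transposed-V scatter indices added further down (build_input_v_idxs / set_input_v_idxs) are computed once per graph from hparams.n_embd_v_gqa() with no layer argument and then reused for every layer's V cache, so the assert guarantees that all cached layers report the same n_embd_v_gqa. As a rough idea of what the helper checks, here is a hypothetical sketch; the actual is_n_embd_v_gqa_homogeneous() implementation may differ:

```cpp
// Hypothetical sketch only: a per-layer equality check over n_embd_v_gqa.
// llama_hparams and n_embd_v_gqa(il) are the existing llama.cpp types/accessors.
static bool is_n_embd_v_gqa_homogeneous_sketch(const llama_hparams & hparams, uint32_t n_layer) {
    for (uint32_t il = 1; il < n_layer; ++il) {
        if (hparams.n_embd_v_gqa(il) != hparams.n_embd_v_gqa(0)) {
            return false; // a layer with a different value would break the shared v_idxs
        }
    }
    return true;
}
```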
@@ -98,8 +105,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, 1);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, 1);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
@@ -780,33 +787,40 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
 
     auto * k = layers[ikv].k;
 
-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const uint64_t kv_size = get_size();
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, 1,
             ggml_row_size(k->type, hparams.n_embd_head_k),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size),
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const uint64_t kv_size = get_size();
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
+                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)),         // v->nb[2]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size), // v->nb[3]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            0);
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
+            ggml_row_size(v->type, kv_size),                          // v->nb[2]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il))*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
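On the get_k()/get_v() hunk above: since the cache tensors are now allocated with a trailing dimension of size 1 (the ggml_new_tensor_3d change earlier), the attention views become 4D with ne[3] = 1 and nb[3] equal to one full kv_size slab; the trailing `*0` simply spells out that the view starts at slab 0. A minimal plain-C++ sketch (made-up sizes, no ggml) showing that the extra dimension does not change which bytes the view addresses:

```cpp
// Toy offset check: with strides analogous to those in get_k() above, a 4D view
// whose 4th index is always 0 addresses the same bytes as the old 3D view.
// All sizes here are illustrative; n_embd_k_gqa is assumed to be n_embd_head_k*n_head_kv.
#include <cstdint>
#include <cassert>

int main() {
    const uint64_t n_embd_head_k = 128, n_head_kv = 8, kv_size = 4096;
    const uint64_t esz = 2; // e.g. 2 bytes per element for F16

    // strides (in bytes), mirroring the arguments of ggml_view_3d / ggml_view_4d
    const uint64_t nb1 = esz*n_embd_head_k;            // per head
    const uint64_t nb2 = esz*n_embd_head_k*n_head_kv;  // per cache cell (n_embd_k_gqa)
    const uint64_t nb3 = nb2*kv_size;                  // per kv_size slab (new 4th dim)

    // element (i0, i1, i2) of the old 3D view vs (i0, i1, i2, 0) of the new 4D view
    const uint64_t i0 = 5, i1 = 3, i2 = 17;
    const uint64_t off3d = esz*i0 + nb1*i1 + nb2*i2;
    const uint64_t off4d = esz*i0 + nb1*i1 + nb2*i2 + nb3*0; // 4th index is always 0 here
    assert(off3d == off4d);
    return 0;
}
```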
@@ -820,6 +834,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
         return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
@@ -845,24 +863,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
-
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: we can be more explicit here at the cost of extra cont
-        // however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        // v_cur = ggml_transpose(ctx, v_cur);
-        // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1, n_kv, n_embd_v_gqa]
-        // v_cur  [1, n_tokens, n_embd_v_gqa]
-        // v_idxs [n_tokens, 1, 1]
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
@@ -899,7 +911,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+    }
 
     ggml_set_input(v_idxs);
 
@@ -916,7 +934,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         data[i] = sinfo.idxs.at(i);
     }
 }
@@ -931,8 +949,21 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    if (!v_trans) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            data[i] = sinfo.idxs.at(i);
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                data[i*n_embd_v_gqa + j] = j*kv_size + sinfo.idxs.at(i);
+            }
+        }
     }
 }
 
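Putting the v_trans pieces together (cpy_v flattening both tensors to single-element rows, build_input_v_idxs allocating n_tokens*n_embd_v_gqa indices, and the index formula in set_input_v_idxs): each scalar j of token i's V vector is scattered to flat position j*kv_size + cell(i) of the transposed cache. A small standalone simulation with toy sizes (plain C++, no ggml, illustrative values only):

```cpp
// Minimal host-side simulation of the flattened set_rows path for the
// transposed V cache. Sizes and cell positions are made up for illustration.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t kv_size = 8, n_embd_v_gqa = 3, n_tokens = 2;
    const std::vector<int64_t> cells = {5, 6}; // destination cells for the 2 tokens

    // transposed cache: element (cell t, embd j) lives at flat index j*kv_size + t
    std::vector<float> v_cache(kv_size*n_embd_v_gqa, 0.0f);

    // current batch: element (embd j, token i) lives at flat index i*n_embd_v_gqa + j
    std::vector<float> v_cur(n_tokens*n_embd_v_gqa);
    for (int64_t i = 0; i < n_tokens*n_embd_v_gqa; ++i) v_cur[i] = 1.0f + i;

    // indices built exactly as in the v_trans branch of set_input_v_idxs()
    for (int64_t i = 0; i < n_tokens; ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            const int64_t dst = j*kv_size + cells[i]; // row index in the flattened cache
            v_cache[dst] = v_cur[i*n_embd_v_gqa + j]; // one scalar per single-element row
        }
    }

    // token 0 (cell 5) lands at flat offsets 5, 13, 21; token 1 (cell 6) at 6, 14, 22
    printf("%g %g %g\n", v_cache[5], v_cache[13], v_cache[21]);
    return 0;
}
```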