@@ -98,8 +98,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, 1);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, 1);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
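
For intuition, the extra trailing dimension of 1 does not change the memory layout of the cache tensors; it only adds a dimension that later stride math can address. A minimal sketch (sizes made up, not taken from the diff) that should hold for any ggml build exposing these constructors:

// Sketch: a [n_embd, kv_size] 2D tensor and a [n_embd, kv_size, 1] 3D tensor
// occupy the same bytes and share the same row stride; the change above only
// adds a dimension for later per-slab addressing. Sizes are illustrative.
#include "ggml.h"
#include <cassert>

int main() {
    ggml_init_params ip = { /*mem_size*/ 32u*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(ip);

    const int64_t n_embd  = 128;
    const int64_t kv_size = 256;

    ggml_tensor * k2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_embd, kv_size);
    ggml_tensor * k3 = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd, kv_size, 1);

    assert(ggml_nbytes(k2) == ggml_nbytes(k3)); // same total storage
    assert(k2->nb[1] == k3->nb[1]);             // same per-cell row stride

    ggml_free(ctx);
    return 0;
}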
@@ -785,33 +785,40 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
 
     auto * k = layers[ikv].k;
 
-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const uint64_t kv_size = get_size();
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, 1,
             ggml_row_size(k->type, hparams.n_embd_head_k),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size),
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const uint64_t kv_size = get_size();
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
+                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)),         // v->nb[2]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size), // v->nb[3]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            0);
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
+            ggml_row_size(v->type, kv_size),                          // v->nb[2]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il))*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
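
In the new views above, the four extents are followed by three strides and a byte offset: the third stride spans one full n_embd_*_gqa*kv_size slab, and the final `*0` turns that slab stride into an offset of zero, presumably so a later change can select a non-zero slab index. A rough standalone check of the K case (made-up sizes standing in for hparams.* and get_size(), plain F16 cache assumed):

// Sketch: the 4D view with a trailing extent of 1 and a zero offset addresses
// the same elements as the previous 3D view. All sizes are illustrative.
#include "ggml.h"
#include <cassert>

int main() {
    ggml_init_params ip = { 64u*1024*1024, nullptr, false };
    ggml_context * ctx = ggml_init(ip);

    const int64_t n_embd_head_k = 64;
    const int64_t n_head_kv     = 4;
    const int64_t n_embd_k_gqa  = n_embd_head_k*n_head_kv;
    const int64_t kv_size       = 512;
    const int64_t n_kv          = 32;

    ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd_k_gqa, kv_size, 1);

    // old-style 3D view
    ggml_tensor * view3 = ggml_view_3d(ctx, k,
            n_embd_head_k, n_head_kv, n_kv,
            ggml_row_size(k->type, n_embd_head_k),
            ggml_row_size(k->type, n_embd_k_gqa),
            0);

    // new-style 4D view; the slab stride times 0 is the byte offset
    ggml_tensor * view4 = ggml_view_4d(ctx, k,
            n_embd_head_k, n_head_kv, n_kv, 1,
            ggml_row_size(k->type, n_embd_head_k),
            ggml_row_size(k->type, n_embd_k_gqa),
            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*0);

    assert(view3->data == view4->data);                     // same starting byte
    assert(ggml_nelements(view3) == ggml_nelements(view4)); // same element count
    assert(view3->nb[1] == view4->nb[1]);                   // same head stride
    assert(view3->nb[2] == view4->nb[2]);                   // same cell stride

    ggml_free(ctx);
    return 0;
}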
@@ -850,24 +857,16 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
-
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: we can be more explicit here at the cost of extra cont
-        // however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        // v_cur = ggml_transpose(ctx, v_cur);
-        // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1, n_kv,     n_embd_v_gqa]
-        // v_cur  [1, n_tokens, n_embd_v_gqa]
-        // v_idxs [n_tokens, 1, 1]
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
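
With the permute and index-broadcast removed, the transposed branch now expects v_idxs to carry one destination row per (token, channel) pair: both the cache and the incoming values are flattened into rows of a single element, so ggml_set_rows() moves exactly one value per index (the indices themselves are produced in set_input_v_idxs() further down). A shape-only sketch, with made-up sizes standing in for hparams.n_embd_v_gqa(il) and get_size():

// Sketch of the shape bookkeeping in the transposed branch: both operands of
// ggml_set_rows() end up as single-element rows, so the number of indices
// equals the number of scattered elements. Sizes are illustrative only.
#include "ggml.h"
#include <cassert>

int main() {
    ggml_init_params ip = { 64u*1024*1024, nullptr, false };
    ggml_context * ctx = ggml_init(ip);

    const int64_t n_embd_v_gqa = 256;
    const int64_t kv_size      = 512;
    const int64_t n_tokens     = 8;

    // V cache as allocated in the constructor: [n_embd_v_gqa, kv_size, 1]
    // (with v_trans it is read back through transposed strides in get_v())
    ggml_tensor * v     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_v_gqa, kv_size, 1);
    // current ubatch values, already flattened to [n_embd_v_gqa, n_tokens]
    ggml_tensor * v_cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_v_gqa, n_tokens);

    // destination: kv_size*n_embd_v_gqa rows of one element each
    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
    // source: n_embd_v_gqa*n_tokens rows of one element each
    ggml_tensor * v_src  = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);

    assert(v_view->ne[0] == 1 && v_src->ne[0] == 1);   // single-element rows
    assert(v_view->ne[1] == n_embd_v_gqa*kv_size);     // every cache cell, per channel
    assert(v_src->ne[1]  == n_embd_v_gqa*n_tokens);    // one index per written element

    ggml_free(ctx);
    return 0;
}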
@@ -904,7 +903,14 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        // TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+    }
 
     ggml_set_input(v_idxs);
 
@@ -921,7 +927,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         data[i] = sinfo.idxs.at(i);
     }
 }
@@ -936,8 +942,22 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    if (!v_trans) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            data[i] = sinfo.idxs.at(i);
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        // TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                data[i*n_embd_v_gqa + j] = j*kv_size + sinfo.idxs.at(i);
+            }
+        }
     }
 }
 
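
The expansion in the transposed branch mirrors the layout read back by get_v(): once the cache is viewed as single-element rows, channel j of cache cell c sits at flat row j*kv_size + c, so each token contributes n_embd_v_gqa indices (which is also why build_input_v_idxs() above sizes v_idxs as n_tokens*n_embd_v_gqa). A plain C++ sketch, with made-up sizes and slot assignments, that builds the same table, emulates the scatter, and reads the values back through the transposed layout:

// Sketch: emulate the per-element index expansion and the single-element-row
// scatter, then verify the values land where the transposed view expects them.
// kv_size, n_embd_v_gqa and the slot indices below are made up.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const int64_t kv_size      = 16;
    const int64_t n_embd_v_gqa = 4;
    const std::vector<int64_t> slots = { 3, 7, 8 }; // destination cell per token
    const int64_t n_tokens = (int64_t) slots.size();

    // same expansion as the transposed branch of set_input_v_idxs()
    std::vector<int64_t> idxs(n_tokens*n_embd_v_gqa);
    for (int64_t i = 0; i < n_tokens; ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            idxs[i*n_embd_v_gqa + j] = j*kv_size + slots[i];
        }
    }

    // emulate ggml_set_rows() on the flattened, single-element-row cache
    std::vector<float> cache(kv_size*n_embd_v_gqa, 0.0f);
    auto val = [](int64_t i, int64_t j) { return float(100*i + j); }; // fake v_cur values

    for (int64_t i = 0; i < n_tokens; ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            cache[idxs[i*n_embd_v_gqa + j]] = val(i, j);
        }
    }

    // read back with the transposed strides: channel j of cell c is at j*kv_size + c
    for (int64_t i = 0; i < n_tokens; ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            assert(cache[j*kv_size + slots[i]] == val(i, j));
        }
    }
    return 0;
}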