@@ -937,17 +937,17 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
                 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
                 ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                size_virt,
+                size_virt, // v->nb[3]
                 size_virt*sinfo.s0);
     }
 
     // note: v->nb[1] > v->nb[2]
     return ggml_view_4d(ctx, v,
             n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            size_virt,
-            size_virt*sinfo.s0);
+            ggml_row_size(v->type, v->ne[1]*n_seq_virt*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, v->ne[1]*n_seq_virt),                       // v->nb[2]
+            ggml_row_size(v->type, v->ne[1]),                                  // v->nb[3]
+            ggml_row_size(v->type, v->ne[1]*sinfo.s0));
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
@@ -961,20 +961,9 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (kv_idxs && supports_set_rows) {
-        const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
-
-        const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
-
-        ggml_tensor * k_view = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], ns,
-                ggml_row_size(k->type, k->ne[0]),
-                size_virt,
-                size_virt*sinfo.s0);
-
-        k_cur = ggml_reshape_3d(ctx, k_cur, k_cur->ne[0], k_cur->ne[1]/ns, ns);
-
-        kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
+        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
 
-        return ggml_set_rows(ctx, k_view, k_cur, kv_idxs);
+        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
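For reference, a minimal standalone sketch (not llama.cpp code) of the pattern the new cpy_k() relies on: once the per-stream offset is baked into every index (see set_input_kv_idxs() below), the 3D K cache can be flattened to 2D and all new rows scattered with a single ggml_set_rows() call, instead of building a per-sequence view. Toy dimensions and the "stream"/"slot" naming are assumptions of the sketch; ggml_set_rows(), ggml_reshape_2d() and the graph helpers are taken from ggml's public API, but the header that declares ggml_graph_compute_with_ctx() varies between ggml revisions.

// Standalone sketch of the pattern used above (not llama.cpp code): with the
// per-stream offset already baked into the indices, the 3D K cache can be
// flattened to 2D and all new rows scattered in a single ggml_set_rows() call.
// Toy dimensions; the "stream"/"slot" naming is an assumption of this sketch.
#include "ggml.h"
#include "ggml-cpu.h" // assumed location of ggml_graph_compute_with_ctx(); older revisions declare it in ggml.h

#include <cstdint>
#include <cstdio>

int main() {
    const int n_embd    = 4; // row width, k->ne[0]
    const int kv_size   = 8; // slots per stream, k->ne[1] (get_size() in the cache)
    const int n_streams = 2; // virtual sequences, k->ne[2]
    const int n_tokens  = 3;

    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * k     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, kv_size, n_streams);
    ggml_tensor * k_cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
    ggml_tensor * idxs  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    // global row index = stream*kv_size + slot (the same arithmetic set_input_kv_idxs() performs)
    int64_t * ii = (int64_t *) idxs->data;
    ii[0] = 0*kv_size + 5; // stream 0, slot 5
    ii[1] = 1*kv_size + 0; // stream 1, slot 0
    ii[2] = 1*kv_size + 7; // stream 1, slot 7

    for (int64_t i = 0; i < ggml_nelements(k_cur); ++i) {
        ((float *) k_cur->data)[i] = (float) i;
    }

    // flatten (slot, stream) into one row dimension and scatter everything at once
    ggml_tensor * k2d = ggml_reshape_2d(ctx, k, n_embd, kv_size*n_streams);
    ggml_tensor * out = ggml_set_rows(ctx, k2d, k_cur, idxs);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // row 1*kv_size + 0 now holds the second row of k_cur -> prints 4.000000
    printf("%f\n", ggml_get_f32_nd(out, 0, 1*kv_size + 0, 0, 0));

    ggml_free(ctx);
    return 0;
}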
@@ -1000,45 +989,27 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
     if (kv_idxs && supports_set_rows) {
-        const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
-
-        const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
-
         if (!v_trans) {
-            ggml_tensor * v_view = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], ns,
-                    ggml_row_size(v->type, v->ne[0]),
-                    size_virt,
-                    size_virt*sinfo.s0);
-
-            v_cur = ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], v_cur->ne[1]/ns, ns);
-
-            kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
+            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
 
-            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }
 
         // the row becomes a single element
-        ggml_tensor * v_view = ggml_view_4d(ctx, v, 1, v->ne[1], v->ne[0], ns,
-                ggml_row_size(v->type, 1),
-                ggml_row_size(v->type, v->ne[1]),
-                size_virt,
-                size_virt*sinfo.s0);
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);
 
         // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_4d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]/ns, ns), 2, 0, 1, 3);
+        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
 
         // note: we can be more explicit here at the cost of extra cont
         // however, above we take advantage that a row of single element is always contiguous regardless of the row stride
-        // v_cur = ggml_reshape_3d(ctx, v_cur, n_embd_v_gqa, v_cur->ne[1]/ns, ns);
         // v_cur = ggml_transpose(ctx, v_cur);
-        // v_cur = ggml_cont_4d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1], v_cur->ne[2]);
+        // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
         // we broadcast the KV indices n_embd_v_gqa times
-        // v       [1, n_kv,        n_embd_v_gqa, ns]
-        // v_cur   [1, n_tokens/ns, n_embd_v_gqa, ns]
-        // kv_idxs [n_tokens/ns, 1, ns]
-
-        kv_idxs = ggml_reshape_3d(ctx, kv_idxs, n_tokens/ns, 1, ns);
+        // v       [1, n_kv*n_seq_virt, n_embd_v_gqa]
+        // v_cur   [1, n_tokens,        n_embd_v_gqa]
+        // kv_idxs [n_tokens, 1, 1]
 
         return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
     }
@@ -1077,8 +1048,10 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     int64_t * data = (int64_t *) dst->data;
 
     for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+        const int64_t offs = sinfo.seq_id_virt[s]*get_size();
+
         for (uint32_t i = 0; i < sinfo.size(); ++i) {
-            data[s*sinfo.size() + i] = sinfo.idxs[s][i];
+            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
         }
     }
 }
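A small sketch of the index arithmetic in the hunk above, with plain containers standing in for the slot_info type: each virtual sequence (stream) owns get_size() rows of the flattened cache, so adding seq_id_virt[s]*get_size() turns a cell index that is local to one stream into a global row index that ggml_set_rows() can consume directly. The function name and signature here are illustrative, not llama.cpp's.

// Sketch of the offset logic above (illustrative names, simplified types):
// seq_id_virt[s]*kv_size shifts each stream's local cell indices into that
// stream's slice of the flattened [n_embd, kv_size*n_seq_virt] cache view.
#include <cstdint>
#include <vector>

std::vector<int64_t> flatten_kv_idxs(const std::vector<std::vector<uint32_t>> & idxs,        // sinfo.idxs
                                     const std::vector<uint32_t>              & seq_id_virt, // sinfo.seq_id_virt
                                     uint32_t                                   kv_size) {   // get_size()
    std::vector<int64_t> data;

    for (size_t s = 0; s < idxs.size(); ++s) {
        const int64_t offs = (int64_t) seq_id_virt[s]*kv_size; // start of this stream's slice

        for (uint32_t i : idxs[s]) {
            data.push_back(offs + i); // global row index into the flattened 2D cache
        }
    }

    return data;
}

For example, with kv_size = 8 and seq_id_virt = {0, 2}, local cell 3 of the second group maps to global row 2*8 + 3 = 19, i.e. into the slice belonging to stream 2.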