Commit bf40a83

graph : fix stream splitting when KV cache is not used
ggml-ci
1 parent 9880503 commit bf40a83

File tree: 2 files changed, +10 -7 lines changed

src/llama-context.cpp

Lines changed: 2 additions & 0 deletions
@@ -744,6 +744,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     const uint32_t n_tokens = balloc->get_n_tokens();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
     const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
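
The TODO above refers to a split mode that does not exist yet: split_simple packs all tokens into one flat ubatch, so different sequences generally contribute unequal token counts (ubatch.equal_seqs == false). A padded split would extend every sequence to the longest one, after which the batch could be viewed as equal per-sequence streams. A minimal sketch of that padding arithmetic, with made-up token counts (nothing here is llama.cpp API):

// A minimal sketch of the padding behind [TAG_NO_CACHE_PAD]; the split mode
// does not exist yet and the token counts below are invented.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical token counts contributed by each sequence in the batch
    std::vector<int> n_seq_tokens = { 5, 3, 8 };

    // pad every sequence to the longest one so all of them are equal
    const int n_pad = *std::max_element(n_seq_tokens.begin(), n_seq_tokens.end());

    int n_padding = 0;
    for (int n : n_seq_tokens) {
        n_padding += n_pad - n;
    }

    // with equal sequences the ubatch could be split into one stream per sequence
    printf("padded to %d tokens per sequence (%d padding tokens)\n", n_pad, n_padding);

    return 0;
}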

src/llama-graph.cpp

Lines changed: 8 additions & 7 deletions
@@ -983,17 +983,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     const bool v_trans = v->nb[1] > v->nb[2];
 
     // split the batch into streams if needed
-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const auto n_stream = k->ne[3];
 
     q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
 
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_tokens = q->ne[1];
-    const auto n_head = q->ne[2];
-    const auto n_kv = k->ne[1];
+    const auto n_kv = k->ne[1];
 
     ggml_tensor * cur;
 
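This is the core of the fix. Every ggml tensor carries four dimensions ne[0..3], and when K/V come from the KV cache the per-sequence streams live in K's ne[3]. In the no-cache encoder path, however, K is built directly from the current ubatch and has ne[3] == 1, while the old expression still reported ubatch.n_seqs_unq streams, so q was split into streams K did not have. Reading n_stream from K itself keeps the two sides consistent by construction. A standalone model of the shape arithmetic (plain arrays instead of ggml tensors, sizes invented for illustration):

// Models the q split with arrays ordered like ggml's ne[] = {ne0, ne1, ne2, ne3}.
#include <cassert>
#include <cstdio>

int main() {
    const long n_seqs_unq = 4;                 // unique sequences in the batch
    const long k_ne[4]    = { 128, 8, 10, 1 }; // K built from the ubatch; ne[3] == 1: single stream, no cache
    long       q_ne[4]    = { 128, 8, 40, 1 }; // q before the split: {head_dim, n_head, n_tokens, n_stream}

    // old: const long n_stream = n_seqs_unq;  // 4 streams that K does not have
    // new: read the stream count from K itself
    const long n_stream = k_ne[3];

    assert(q_ne[2] % n_stream == 0);           // mirrors q->ne[2]/n_stream in the reshape
    q_ne[2] /= n_stream;
    q_ne[3]  = n_stream;

    printf("%ld sequence(s), q split into %ld stream(s) of %ld token(s)\n",
           n_seqs_unq, q_ne[3], q_ne[2]);
    return 0;
}
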
@@ -1035,8 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        // recombine streams
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1082,7 +1079,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
         // recombine streams
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
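
Both recombine sites (the flash-attention branch above and this fallback path) now derive the 2-D shape from cur itself instead of from the n_head/n_tokens/n_stream values captured before the split, which also lets those now-unused locals be dropped. The products are the same when everything is consistent, but taking them from the tensor makes the recombination correct for any stream layout. The arithmetic, modeled standalone with invented sizes:

// The recombination collapses {head_dim, n_head, n_tokens, n_stream} to 2-D.
#include <cstdio>

int main() {
    // hypothetical attention output after the streams were processed
    const long ne[4] = { 128, 8, 10, 1 };  // {head_dim, n_head, n_tokens, n_stream}

    // both factors are read off the tensor itself, as in the new code:
    const long ne0 = ne[0]*ne[1];          // cur->ne[0]*cur->ne[1]: embedding width
    const long ne1 = ne[2]*ne[3];          // cur->ne[2]*cur->ne[3]: all tokens across streams

    printf("recombined to [%ld, %ld]\n", ne0, ne1);
    return 0;
}
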
@@ -1129,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+    assert(ubatch.equal_seqs == false);
+
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
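
The assert pins down the current contract: without a KV cache, q/k/v come straight from a ubatch produced by split_simple, which never groups tokens per sequence, so equal_seqs must be false and build_attn_mha must see a single stream (matching k->ne[3] == 1 above). Once the padded split mode from [TAG_NO_CACHE_PAD] exists, equal_seqs == true would permit splitting these tensors into ubatch.n_seqs_unq streams. A toy model of the equal_seqs predicate (field names are made up, merely mirroring llama_ubatch):

// Toy model of what ubatch.equal_seqs asserts; not llama.cpp API.
#include <cassert>
#include <vector>

struct ubatch_model {
    std::vector<int> seq_tokens;             // tokens contributed by each sequence

    bool equal_seqs() const {
        for (int n : seq_tokens) {
            if (n != seq_tokens[0]) {
                return false;
            }
        }
        return true;
    }
};

int main() {
    // split_simple: one flat run of tokens, sequence lengths generally differ
    const ubatch_model simple = { { 5, 3, 8 } };
    assert(!simple.equal_seqs());            // the state the new assert requires

    // the padded split mode from [TAG_NO_CACHE_PAD] would establish this instead
    const ubatch_model padded = { { 8, 8, 8 } };
    assert(padded.equal_seqs());
    return 0;
}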
