Commit 20e317c

graph : fix stream splitting when KV cache is not used
ggml-ci
1 parent: 7b00429

2 files changed: +10 -7 lines


src/llama-context.cpp

Lines changed: 2 additions & 0 deletions

@@ -744,6 +744,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     const uint32_t n_tokens = balloc->get_n_tokens();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
     const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
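
The TODO added here names a possible future split mode: pad each input sequence to a common length so that ubatch.equal_seqs holds. A minimal standalone sketch of that padding idea, working on plain token vectors with an assumed pad_token_id rather than the real llama_ubatch machinery:

// Hypothetical sketch of the padding idea, not llama.cpp API: right-pad
// every sequence to the length of the longest one so that all sequences
// contribute the same number of tokens per ubatch (equal_seqs).
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;

static void pad_sequences(std::vector<std::vector<llama_token>> & seqs,
                          llama_token pad_token_id) { // pad id is assumed
    size_t n_max = 0;
    for (const auto & s : seqs) {
        n_max = std::max(n_max, s.size());
    }
    for (auto & s : seqs) {
        s.resize(n_max, pad_token_id); // extend shorter sequences with pads
    }
}

int main() {
    std::vector<std::vector<llama_token>> seqs = {{11, 12, 13}, {21, 22}};
    pad_sequences(seqs, /*pad_token_id =*/ 0);
    std::printf("lengths: %zu %zu\n", seqs[0].size(), seqs[1].size()); // 3 3
}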

src/llama-graph.cpp

Lines changed: 8 additions & 7 deletions
@@ -1034,17 +1034,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     const bool v_trans = v->nb[1] > v->nb[2];
 
     // split the batch into streams if needed
-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const auto n_stream = k->ne[3];
 
     q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
 
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_tokens = q->ne[1];
-    const auto n_head   = q->ne[2];
-    const auto n_kv     = k->ne[1];
+    const auto n_kv = k->ne[1];
 
     ggml_tensor * cur;
 
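
The core of the fix: take the stream count from K's own fourth dimension instead of recomputing it from cparams and the ubatch. When no KV cache is used, K and V come straight from the ubatch with a single stream (k->ne[3] == 1), even if the context was configured with kv_unified == false and the batch holds several unique sequences, so the old formula could claim streams that K/V do not actually have. A shape-only sketch of the split arithmetic (no ggml calls; ggml stores dimensions fastest-first in ne[0..3], and the concrete numbers and q layout below are illustrative assumptions):

#include <cstdint>
#include <cstdio>

int main() {
    // K as it arrives without a KV cache: a single stream in ne[3].
    const int64_t k_ne[4] = {128, 16, 8, 1}; // {head_dim, n_kv, n_head_kv, n_stream}

    // the fix: trust K's own layout instead of cparams/ubatch bookkeeping
    const int64_t n_stream = k_ne[3];

    // q before the split, assumed {head_dim, n_head, n_tokens, 1}:
    const int64_t q_ne[4] = {128, 8, 32, 1};

    // ggml_reshape_4d(q, ne0, ne1, ne2/n_stream, n_stream):
    int64_t q4[4] = {q_ne[0], q_ne[1], q_ne[2]/n_stream, n_stream};

    // ggml_permute(q, 0, 2, 1, 3) swaps ne[1] and ne[2]:
    int64_t qp[4] = {q4[0], q4[2], q4[1], q4[3]};

    std::printf("q split: [%lld, %lld, %lld, %lld]\n",
            (long long) qp[0], (long long) qp[1],
            (long long) qp[2], (long long) qp[3]);
    // with the old n_stream = n_seqs_unq (> 1), this reshape would divide
    // the token dimension by a stream count K/V were never split into
}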
@@ -1086,8 +1084,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        // recombine streams
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1133,7 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
         // recombine streams
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
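
Both recombine sites (the flash-attention branch above and this one) change the same way: rather than collapsing with the n_head, n_tokens and n_stream values captured before the split, they read the dimensions back off the tensor itself, so the 2D collapse stays consistent with whatever split actually happened, including n_stream == 1. A shape-only sketch with made-up values:

#include <cstdint>
#include <cstdio>

int main() {
    // kqv after the final permute: {head_dim, n_head, n_tokens/n_stream, n_stream}
    const int64_t cur_ne[4] = {128, 8, 16, 2};

    // ggml_cont_2d(cur, ne0*ne1, ne2*ne3) -> {n_embd, n_tokens_total}:
    const int64_t ne0 = cur_ne[0]*cur_ne[1]; // 1024 = head_dim * n_head
    const int64_t ne1 = cur_ne[2]*cur_ne[3]; //   32 = tokens across all streams

    std::printf("recombined: [%lld, %lld]\n", (long long) ne0, (long long) ne1);
}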
@@ -1180,6 +1177,10 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+    assert(ubatch.equal_seqs == false);
+
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
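
The new assert pins down the invariant this no-cache path now relies on: split_simple never produces equal-length sequences, so the per-sequence stream split described in the TODO is not attempted. A shape-only sketch of what that TODO would do once equal_seqs holds (made-up values; the struct is a hypothetical stand-in, not the real llama_ubatch):

#include <cassert>
#include <cstdint>
#include <cstdio>

struct ubatch_view {          // hypothetical stand-in for llama_ubatch
    bool    equal_seqs;
    int64_t n_seqs_unq;
};

int main() {
    const ubatch_view ub = {false, 2};    // what split_simple produces today
    int64_t qkv_ne[4] = {128, 8, 32, 1};  // {head_dim, n_head, n_tokens, 1}

    if (ub.equal_seqs) {
        // would-be split: {head_dim, n_head, n_tokens/n_seqs, n_seqs}
        qkv_ne[2] /= ub.n_seqs_unq;
        qkv_ne[3]  = ub.n_seqs_unq;
    } else {
        // until then, build_attn asserts the invariant and keeps one stream
        assert(qkv_ne[3] == 1);
    }
    std::printf("n_stream = %lld\n", (long long) qkv_ne[3]);
}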
