
Commit 7b00429

llama : use "stream" vs "virtual sequence"
ggml-ci
1 parent 38479e2 commit 7b00429

11 files changed: +197 -177 lines changed

examples/parallel/parallel.cpp

Lines changed: 1 addition & 2 deletions
@@ -236,7 +236,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
+    llama_batch batch = llama_batch_init(n_ctx, 0, 1);
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
@@ -290,7 +290,6 @@ int main(int argc, char ** argv) {
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
                 llama_memory_seq_rm(mem, i, -1, -1);
-
                 // but keep the system prompt
                 llama_memory_seq_cp(mem, 0, i, -1, -1);
             }
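Context for the change above: with one KV stream per client there is no need to size the batch for all clients at once; a single context worth of tokens is enough, since the main loop already chunks work into params.n_batch-sized pieces. A minimal sketch of the call, using hypothetical values in place of the example's command-line parameters:

    #include "llama.h"

    // hypothetical stand-in for the n_ctx value read from the context params
    const int n_ctx = 4096;

    // 0 => plain token batch (no embeddings), 1 => one seq-id slot per token
    llama_batch batch = llama_batch_init(n_ctx, /*embd*/ 0, /*n_seq_max*/ 1);

    // ... fill the batch and call llama_decode() in n_batch-sized chunks ...

    llama_batch_free(batch);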

src/llama-context.cpp

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@ llama_context::llama_context(
     }
 
     const char * LLAMA_HT = getenv("LLAMA_HT");
-    cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
+    cparams.kv_unified = (LLAMA_HT && atoi(LLAMA_HT) > 0) ? false : true;
 
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
@@ -270,7 +270,7 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = 1; // reserve worst-case graph for single-sequence batches
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -303,7 +303,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, mctx.get());
+            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
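The gating above reads as: the unified KV layout stays the default, and setting the LLAMA_HT environment variable to a positive value opts into one stream per sequence, which in turn makes the worst-case graph reservation cover n_seq_max streams instead of one. A rough standalone sketch of that logic (the helper names are illustrative, not part of the llama.cpp API):

    #include <cstdlib>
    #include <cstdint>

    // illustrative helper: unified unless LLAMA_HT is set to a positive value
    static bool kv_unified_from_env() {
        const char * ht = std::getenv("LLAMA_HT");
        return !(ht && std::atoi(ht) > 0);
    }

    // illustrative helper: stream count used when reserving the worst-case graph
    static uint32_t worst_case_n_seqs(bool kv_unified, uint32_t n_seq_max) {
        return kv_unified ? 1 : n_seq_max;
    }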

src/llama-cparams.h

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,6 @@ struct llama_cparams {
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
-    uint32_t n_seq_virt;
     int32_t  n_threads;       // number of threads to use for generation
     int32_t  n_threads_batch; // number of threads to use for batch processing
 
@@ -34,6 +33,7 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
     bool op_offload;
+    bool kv_unified;
 
     enum llama_pooling_type pooling_type;

src/llama-graph.cpp

Lines changed: 16 additions & 13 deletions
@@ -1000,13 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
-        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+        const auto n_kv     = inp->mctx->get_attn()->get_n_kv();
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
         inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1033,9 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
                  float     kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
-    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+    // split the batch into streams if needed
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
-    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
 
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
@@ -1085,7 +1086,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
+        // recombine streams
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1130,7 +1132,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
+        // recombine streams
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1207,13 +1210,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv   = mctx_cur->get_n_kv();
-        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+        const auto n_kv     = mctx_cur->get_n_kv();
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1455,15 +1458,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
-    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1477,7 +1480,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
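The recurring ggml_new_tensor_4d call above encodes the stream split: a batch of n_tokens is divided across n_stream streams, the per-stream token count is padded to GGML_KQ_MASK_PAD, and the stream count becomes the last tensor dimension. A sketch of just that shape computation, assuming a valid ggml_context and the GGML_KQ_MASK_PAD constant from the llama.cpp internals:

    // builds the F32 mask of shape [n_kv, pad(n_tokens/n_stream), 1, n_stream]
    static ggml_tensor * build_kq_mask(ggml_context * ctx0, int64_t n_kv, int64_t n_tokens, int64_t n_stream) {
        ggml_tensor * mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32,
                n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
        ggml_set_input(mask); // mark as a graph input to be filled at decode time
        return mask;
    }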

src/llama-graph.h

Lines changed: 8 additions & 8 deletions
@@ -257,8 +257,8 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_seq, 1, n_seq]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -293,10 +293,10 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_seq, 1, n_seq]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_seq, 1, n_seq]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -343,8 +343,8 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_seq, 1, n_seq]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 7 additions & 7 deletions
@@ -18,17 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool v_trans,
         bool offload,
         bool swa_full,
+        bool unified,
         uint32_t kv_size,
         uint32_t n_seq_max,
-        uint32_t n_seq_virt,
         uint32_t n_ubatch,
-        uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
+        uint32_t n_pad) : hparams(model.hparams), unified(unified) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
@@ -42,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
+            v_trans, offload, unified, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
            model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
+            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
 
@@ -101,7 +101,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
     // first try simple split
     do {
-        if (n_seq_virt > 1) {
+        if (!unified) {
             // requires equal splits, so we skip the simple split
             break;
         }
@@ -146,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
     std::vector<llama_ubatch> ubatches;
     while (true) {
-        auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
+        auto ubatch = balloc.split_equal(n_ubatch, !unified);
 
         if (ubatch.n_tokens == 0) {
             break;
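The size_swa expression above is where the unified flag changes capacity planning: with a unified cache all sequences share one stream, so the sliding window must be provisioned n_seq_max times, while per-stream caches only need one window each. A worked sketch with made-up numbers (n_swa = 4096, n_seq_max = 4, n_ubatch = 512, n_pad = 256, size_base = 32768):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // same rounding as GGML_PAD, repeated here so the sketch is self-contained
    #define PAD_TO(x, n) (((x) + (n) - 1) / (n) * (n))

    int main() {
        const uint32_t n_swa = 4096, n_seq_max = 4, n_ubatch = 512, n_pad = 256;
        const uint32_t size_base = 32768;

        const bool unified = true; // shared stream: window must cover all sequences

        const uint32_t size_swa = std::min(size_base,
                (uint32_t) PAD_TO(n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));

        std::printf("size_swa = %u\n", size_swa); // 16896 here; 4608 with unified = false
        return 0;
    }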

src/llama-kv-cache-unified-iswa.h

Lines changed: 2 additions & 2 deletions
@@ -20,9 +20,9 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
             bool v_trans,
             bool offload,
             bool swa_full,
+            bool unified,
             uint32_t kv_size,
             uint32_t n_seq_max,
-            uint32_t n_seq_virt,
             uint32_t n_ubatch,
             uint32_t n_pad);
 
@@ -69,7 +69,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
 private:
     const llama_hparams & hparams;
 
-    const uint32_t n_seq_virt = 1;
+    const bool unified;
 
     std::unique_ptr<llama_kv_cache_unified> kv_base;
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
