Commit 45ecf84

llama : add "virtual sequences"

ggml-ci

1 parent 5a35475 commit 45ecf84

12 files changed: +547 −229 lines

examples/parallel/parallel.cpp

Lines changed: 2 additions & 1 deletion
@@ -236,7 +236,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(n_ctx, 0, 1);
+    llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
@@ -290,6 +290,7 @@ int main(int argc, char ** argv) {
                 // all sequences have ended - clear the entire KV cache
                 for (int i = 1; i <= n_clients; ++i) {
                     llama_memory_seq_rm(mem, i, -1, -1);
+
                     // but keep the system prompt
                     llama_memory_seq_cp(mem, 0, i, -1, -1);
                 }
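For context, a minimal sketch (not part of the commit; the helper name is hypothetical) of why the batch is now sized for n_ctx*n_clients: with a dedicated KV stream per sequence, every client may submit up to a full context of tokens into the same batch, so the batch allocation scales with the number of clients.

```cpp
#include "llama.h"

// Hypothetical sizing helper mirroring the change above: each of n_clients
// sequences may contribute up to n_ctx tokens to a single llama_batch.
static llama_batch make_parallel_batch(int32_t n_ctx, int32_t n_clients) {
    // embd = 0      -> token-id batch (no embeddings)
    // n_seq_max = 1 -> at most one sequence id per token, as in the example
    return llama_batch_init(/*n_tokens =*/ n_ctx*n_clients, /*embd =*/ 0, /*n_seq_max =*/ 1);
}
```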

src/llama-context.cpp

Lines changed: 4 additions & 1 deletion
@@ -33,6 +33,9 @@ llama_context::llama_context(
         throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }
 
+    const char * LLAMA_HT = getenv("LLAMA_HT");
+    cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
+
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -267,7 +270,7 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = 1; // reserve worst-case graph for single-sequence batches
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
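A small standalone sketch (not from the commit; the function name is hypothetical) of the opt-in gate introduced here: when the LLAMA_HT environment variable is set, every sequence gets its own virtual KV stream (n_seq_virt == n_seq_max); when it is unset, n_seq_virt stays at 1 and the existing unified layout is kept.

```cpp
#include <cstdlib>
#include <cstdint>

// Hypothetical helper isolating the gate above.
static uint32_t resolve_n_seq_virt(uint32_t n_seq_max) {
    const char * env = std::getenv("LLAMA_HT");
    // set   -> one virtual stream per sequence
    // unset -> single KV buffer shared by all sequences (current behavior)
    return env ? n_seq_max : 1;
}
```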

src/llama-cparams.h

Lines changed: 3 additions & 2 deletions
@@ -11,8 +11,9 @@ struct llama_cparams {
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
+    uint32_t n_seq_virt;
+    int32_t  n_threads;       // number of threads to use for generation
+    int32_t  n_threads_batch; // number of threads to use for batch processing
 
     float rope_freq_base;
     float rope_freq_scale;

src/llama-graph.cpp

Lines changed: 13 additions & 6 deletions
@@ -982,6 +982,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
                   float   kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
@@ -1030,7 +1034,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1075,7 +1079,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1156,13 +1160,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
 {
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-    const auto n_kv = mctx_cur->get_n_kv();
+    const auto n_kv     = mctx_cur->get_n_kv();
     const auto n_tokens = ubatch.n_tokens;
+    const auto n_seqs   = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1362,13 +1367,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1382,7 +1389,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
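A self-contained ggml sketch (toy sizes, all values hypothetical, not part of the commit) of the shape bookkeeping this diff introduces: when n_seq_virt > 1, the token dimension of Q is split per sequence and the KQ mask gains a per-sequence dimension, so each sequence's queries are masked against its own KV cells.

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ false };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 8, n_head = 4, n_tokens = 16, n_seqs = 2, n_kv = 64;

    // Q as built by the graph: [n_embd_head, n_head, n_tokens]
    ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);

    // split the token dimension per sequence: [n_embd_head, n_head, n_tokens/n_seqs, n_seqs]
    q = ggml_reshape_4d(ctx, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);

    // permute for the attention matmul: [n_embd_head, n_tokens/n_seqs, n_head, n_seqs]
    q = ggml_permute(ctx, q, 0, 2, 1, 3);

    // the KQ mask now carries one slice per sequence: [n_kv, n_tokens/n_seqs, 1, n_seqs]
    ggml_tensor * kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_seqs, 1, n_seqs);

    std::printf("q:       [%lld, %lld, %lld, %lld]\n",
        (long long) q->ne[0], (long long) q->ne[1], (long long) q->ne[2], (long long) q->ne[3]);
    std::printf("kq_mask: [%lld, %lld, %lld, %lld]\n",
        (long long) kq_mask->ne[0], (long long) kq_mask->ne[1], (long long) kq_mask->ne[2], (long long) kq_mask->ne[3]);

    ggml_free(ctx);
    return 0;
}
```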

src/llama-graph.h

Lines changed: 9 additions & 9 deletions
@@ -255,10 +255,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_seq, 1, n_seq]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -289,14 +289,14 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
     ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_seq, 1, n_seq]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_seq, 1, n_seq]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 11 additions & 5 deletions
@@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                  bool   swa_full,
              uint32_t   kv_size,
              uint32_t   n_seq_max,
+             uint32_t   n_seq_virt,
              uint32_t   n_ubatch,
-             uint32_t   n_pad) : hparams(model.hparams) {
+             uint32_t   n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, n_seq_max, n_pad,
+            v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, n_seq_max, n_pad,
+            v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
 
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
     // first try simple split
     do {
+        if (n_seq_virt > 1) {
+            // requires equal splits, so we skip the simple split
+            break;
+        }
+
        balloc.split_reset();
 
        std::vector<llama_ubatch> ubatches;
@@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
        std::vector<llama_ubatch> ubatches;
        while (true) {
-           auto ubatch = balloc.split_equal(n_ubatch, false);
+           auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
 
            if (ubatch.n_tokens == 0) {
                break;
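A worked example (hypothetical sizes; pad_to stands in for GGML_PAD) of the SWA sizing change above: with one virtual stream per sequence, the SWA cache no longer reserves a sliding window for every sequence inside a single shared buffer, so the per-stream cell count shrinks accordingly. Note also that when n_seq_virt > 1 the simple split is skipped and split_equal is used with sequential splits, since the virtual streams require equal-length per-sequence chunks.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// same rounding idea as GGML_PAD
static uint32_t pad_to(uint32_t x, uint32_t n) { return ((x + n - 1)/n)*n; }

int main() {
    const uint32_t n_swa = 4096, n_ubatch = 512, n_pad = 256, kv_size = 65536;
    const uint32_t n_seq_max = 4;

    // before: SWA cells reserved for every sequence in the shared buffer
    const uint32_t size_old = std::min(kv_size, pad_to(n_swa*n_seq_max + n_ubatch, n_pad));

    // after, with one virtual stream per sequence (n_seq_virt == n_seq_max):
    // each stream only needs a single window worth of cells
    const uint32_t n_seq_virt = n_seq_max;
    const uint32_t size_new = std::min(kv_size, pad_to(n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));

    std::printf("size_swa before: %u cells shared, after: %u cells per stream\n", size_old, size_new);
    return 0; // prints 16896 vs 4608 for these toy numbers
}
```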

src/llama-kv-cache-unified-iswa.h

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
                  bool   swa_full,
              uint32_t   kv_size,
              uint32_t   n_seq_max,
+             uint32_t   n_seq_virt,
              uint32_t   n_ubatch,
              uint32_t   n_pad);
 
@@ -68,6 +69,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
 private:
     const llama_hparams & hparams;
 
+    const uint32_t n_seq_virt = 1;
+
     std::unique_ptr<llama_kv_cache_unified> kv_base;
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
