
Commit 01612b7

llama : reuse compute graphs (#14482)
* llama : reuse compute graphs
* llama-bench : add graph reuse parameter
* cont : remove the parameter and the sched resets
* graph : rename update() to can_reuse()
* params : remove is_same()
* graph : set res->params in llm_graph_context constructor
* graph : avoid set_max_nodes in llm_graph_result
* kv-cache : reuse llama_context's graph result instance
* context : reset the previous graph result upon memory updates
* batch : llama_ubatch now carries its data instead of pointing to balloc
* merge : fix build
* graph : fix can_reuse() checks when flash-attention is disabled
* graph : move llm_graph_result impl in source file + debug env
1 parent 086cf81 commit 01612b7

12 files changed, +542 -283 lines
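Before the per-file diffs, a quick illustration of the idea behind the change: when consecutive ubatches have compatible shapes and parameters, the previously built ggml compute graph can be reused instead of rebuilt, and the reuse count is tracked. The sketch below only illustrates that reuse pattern; the names (`GraphParams`, `GraphResult`, `GraphCache`) are hypothetical and are not the actual `llm_graph_result` / `can_reuse()` API.

```cpp
// Hypothetical sketch of the reuse pattern described by this commit: keep the
// last graph result around and rebuild only when the new parameters are not
// compatible with it. Names are illustrative, not the llama.cpp API.
#include <cstdint>
#include <cstdio>
#include <memory>

struct GraphParams {
    uint32_t n_tokens;
    uint32_t n_seqs;
    bool     embd_input;

    // conservative check: reuse only when the shapes that determine the
    // graph topology are unchanged
    bool can_reuse(const GraphParams & other) const {
        return n_tokens   == other.n_tokens &&
               n_seqs     == other.n_seqs   &&
               embd_input == other.embd_input;
    }
};

struct GraphResult {
    GraphParams params;
    // ... the built graph / tensors would live here ...
};

struct GraphCache {
    std::unique_ptr<GraphResult> prev;
    int32_t n_reused = 0; // mirrors the new n_reused perf counter

    GraphResult * get(const GraphParams & params) {
        if (prev && prev->params.can_reuse(params)) {
            n_reused++;      // reuse: skip rebuilding, only the inputs change
            return prev.get();
        }
        prev = std::make_unique<GraphResult>(GraphResult{params});
        return prev.get();   // rebuilt from scratch
    }
};

int main() {
    GraphCache cache;
    cache.get({512, 1, false});
    cache.get({512, 1, false}); // same shapes -> reused
    cache.get({256, 1, false}); // different shapes -> rebuilt
    std::printf("n_reused = %d\n", cache.n_reused); // prints 1
}
```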

include/llama.h
Lines changed: 1 addition & 0 deletions

@@ -1394,6 +1394,7 @@ extern "C" {
 
         int32_t n_p_eval;
         int32_t n_eval;
+        int32_t n_reused; // number of times a ggml compute graph had been reused
     };
 
     struct llama_perf_sampler_data {
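The new `n_reused` counter sits alongside `n_p_eval` and `n_eval` in the context perf data. Assuming the surrounding struct is `llama_perf_context_data` and the usual `llama_perf_context()` accessor (and a build that includes this commit), it can be read after decoding, for example:

```cpp
// Sketch: reading the new counter after a few llama_decode() calls.
// Assumes a llama_context * ctx that has already processed some batches and
// that the API matches the header shown above.
#include "llama.h"
#include <cstdio>

void print_graph_reuse_stats(const llama_context * ctx) {
    const llama_perf_context_data data = llama_perf_context(ctx);

    // n_reused counts how many times a previously built ggml compute graph
    // was reused instead of being constructed again
    std::printf("prompt tokens: %d, generated tokens: %d, graphs reused: %d\n",
                data.n_p_eval, data.n_eval, data.n_reused);
}
```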

src/llama-batch.cpp
Lines changed: 56 additions & 59 deletions

@@ -210,7 +210,7 @@ bool llama_batch_allocr::init(
     LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
 
     llama_ubatch ubatch {
-        /*.equal_seqs   =*/ false,
+        /*.b_equal_seqs =*/ false,
         /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
         /*.n_seq_tokens =*/ (uint32_t) 1,
         /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
@@ -223,6 +223,7 @@ bool llama_batch_allocr::init(
         /*.seq_id_unq   =*/ this->seq_id_unq.data(),
         /*.seq_idx      =*/ this->seq_idx.data(),
         /*.output       =*/ batch.logits,
+        /*.data         =*/ {},
     };
 
     ubatch_print(ubatch, debug);
@@ -366,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
     clear();
     split_reset();
 
-    ubatches.emplace_back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
-    auto & ubatch = ubatches.back();
-
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .clear();
-    ubatch.pos       .resize(n_tokens);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .clear();
+    udata->pos       .resize(n_tokens);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+        udata->seq_idx[s] = s;
+        udata->seq_id_unq.push_back(s);
     }
 
     llama_ubatch res {
-        /*.equal_seqs   =*/ true,
+        /*.b_equal_seqs =*/ true,
         /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
 
-        /*.token        =*/ ubatch.token.data(),
+        /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };
 
     return res;
@@ -439,8 +439,6 @@ void llama_batch_allocr::split_reset() {
 
     used.clear();
     used.resize(get_n_tokens(), false);
-
-    ubatches.clear();
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
@@ -655,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
     assert(n_tokens%n_seqs == 0);
 
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
     const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
 
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
 
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .resize(n_embd_all);
-    ubatch.pos       .resize(n_pos_all);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .resize(n_embd_all);
+    udata->pos       .resize(n_pos_all);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     seq_set_t seq_set_unq;
 
     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
-            ubatch.token[i] = batch.token[idxs[i]];
+            udata->token[i] = batch.token[idxs[i]];
         }
 
         if (batch.embd) {
-            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
         }
 
         for (int j = 0; j < n_pos_cur; ++j) {
-            ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
         }
 
-        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
-        ubatch.seq_id[i] = batch.seq_id[idxs[i]];
-        ubatch.output[i] = batch.logits[idxs[i]];
+        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        udata->seq_id[i] = batch.seq_id[idxs[i]];
+        udata->output[i] = batch.logits[idxs[i]];
 
-        for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
-            seq_set_unq.set(ubatch.seq_id[i][s]);
+        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
+            seq_set_unq.set(udata->seq_id[i][s]);
         }
 
-        if (ubatch.output[i]) {
+        if (udata->output[i]) {
            out_ids.push_back(idxs[i]);
        }
    }
 
    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (seq_set_unq.test(s)) {
-            ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
-            ubatch.seq_id_unq.push_back(s);
+            udata->seq_idx[s] = udata->seq_id_unq.size();
+            udata->seq_id_unq.push_back(s);
        }
    }
 
    llama_ubatch res {
-        /*.equal_seqs   =*/ equal_seqs,
+        /*.b_equal_seqs =*/ equal_seqs,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ (uint32_t) ubatch.seq_id_unq.size(),
-
-        /*.token        =*/ batch.token ? ubatch.token.data() : nullptr,
-        /*.embd         =*/ batch.embd ? ubatch.embd.data() : nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };
 
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
+        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
 
         ubatch_print(res, debug);
     }
@@ -736,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
 void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs());
         LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
         LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
         LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
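The recurring change above is that each `llama_ubatch` now owns its backing buffers through a `std::shared_ptr` instead of borrowing them from the allocator's `ubatches` vector, so a ubatch stays valid even after `split_reset()` clears the allocator state. A generic sketch of that "raw view plus shared owner" pattern is below; the `view_t` / `make_view` names are hypothetical, not the llama.cpp types.

```cpp
// Generic sketch: the struct keeps cheap raw pointers for consumers, plus a
// shared_ptr that keeps the pointed-to storage alive. Names are illustrative.
#include <cstdio>
#include <memory>
#include <vector>

struct view_t {
    const int * token;   // raw pointer used by consumers
    size_t      n_tokens;

    struct data_t {
        std::vector<int> token;
    };
    std::shared_ptr<data_t> data; // keeps the pointed-to storage alive
};

view_t make_view(size_t n_tokens) {
    auto udata = std::make_shared<view_t::data_t>();
    udata->token.resize(n_tokens, 42);

    return view_t {
        /*.token    =*/ udata->token.data(),
        /*.n_tokens =*/ n_tokens,
        /*.data     =*/ std::move(udata), // ownership travels with the view
    };
}

int main() {
    view_t v = make_view(4);
    // no allocator or "ubatches" vector needs to outlive v:
    std::printf("first token: %d\n", v.token[0]);
}
```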

src/llama-batch.h
Lines changed: 21 additions & 17 deletions

@@ -8,12 +8,17 @@
 #include <vector>
 #include <set>
 #include <bitset>
+#include <memory>
 #include <unordered_map>
 
 // keep this struct lightweight
-// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
-    bool equal_seqs;
+    bool equal_seqs() const {
+        return b_equal_seqs != 0;
+    }
+
+    uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
+                           // otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
@@ -34,6 +39,20 @@ struct llama_ubatch {
     llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
     int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
     int8_t * output; // [n_tokens] | i | -
+
+    struct data_t {
+        std::vector<llama_token> token;
+        std::vector<float> embd;
+        std::vector<llama_pos> pos;
+        std::vector<int32_t> n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id> seq_id_unq;
+        std::vector<int32_t> seq_idx;
+        std::vector<int8_t> output;
+    };
+
+    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    std::shared_ptr<data_t> data;
 };
 
 // a helper for sanitizing, fulfilling and splitting a batch
@@ -137,20 +156,5 @@ class llama_batch_allocr {
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token> token;
-        std::vector<float> embd;
-        std::vector<llama_pos> pos;
-        std::vector<int32_t> n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id> seq_id_unq;
-        std::vector<int32_t> seq_idx;
-        std::vector<int8_t> output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
-
     int debug;
 };
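The header also replaces the `bool equal_seqs` field with a `uint32_t b_equal_seqs` plus an `equal_seqs()` accessor; per the source comment above, the wider integer sidesteps AddressSanitizer alignment complaints while call sites keep a boolean interface. A minimal sketch of that flag-as-integer pattern, with illustrative names:

```cpp
// Sketch: store the flag in a 32-bit field (friendlier to aggregate layouts
// and sanitizers), expose it through a boolean accessor for readability.
#include <cstdint>
#include <cstdio>

struct flags_t {
    uint32_t b_equal_seqs; // 0 or 1; integer-width storage

    bool equal_seqs() const {
        return b_equal_seqs != 0;
    }
};

int main() {
    flags_t f { /*.b_equal_seqs =*/ 1 };
    std::printf("equal_seqs = %d\n", f.equal_seqs() ? 1 : 0);
}
```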
