
Commit 26439ad

ggerganov authored and Minh141120 committed
ubatch : new splitting logic (ggml-org#14217)
ggml-ci
1 parent b85c660 commit 26439ad

19 files changed: +514 -732 lines changed

src/llama-batch.cpp

Lines changed: 207 additions & 341 deletions
Large diffs are not rendered by default.

src/llama-batch.h

Lines changed: 92 additions & 66 deletions
@@ -7,7 +7,11 @@
 #include <array>
 #include <vector>
 #include <set>
+#include <bitset>
+#include <unordered_map>
 
+// keep this struct lightweight
+// it points to data in `llama_batch_allocr`
 // keep this struct lightweight
 // it points to data in `llama_batch_allocr`
 struct llama_ubatch {
@@ -19,105 +23,127 @@ struct llama_ubatch {
     uint32_t n_seqs;     // sequence sets in the ubatch
     uint32_t n_seqs_unq; // unique sequence ids in the ubatch
 
-    llama_token * token;    // [n_tokens]
-    float * embd;           // [n_embd, n_tokens]
-    llama_pos * pos;        // [n_tokens]
-    int32_t * n_seq_id;     // [n_seqs]
-    llama_seq_id ** seq_id; // [n_seqs]
-    int8_t * output;        // [n_tokens]
-};
-
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-
-    llama_seq_id * seq_id;
-
-    size_t offset;
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    // only for debugging purposes
-    const llama_vocab * vocab;
-
-    // sorted indices into the batch
-    std::vector<int64_t> ids;
-    // batch indices of the output
-    std::vector<int64_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
-
-    // buffers for the ubatches
-    // TODO: very hacky, this needs a complete rework
-    struct ubatch_data {
-        std::vector<llama_token> token;
-        std::vector<float> embd;
-        std::vector<llama_pos> pos;
-        std::vector<int32_t> n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<int8_t> output;
-    };
-
-    std::vector<ubatch_data> udatas;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch);
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch);
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch);
-
-    llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+    // seq_id_unq: unique sequence ids in the ubatch
+    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+    //             used for extracting sequence pooled embeddings
+
+    //                          // size               | idx | val
+    llama_token *   token;      // [n_tokens]         | i   | id, token
+    float *         embd;       // [n_embd, n_tokens] | i   | embd
+    llama_pos *     pos;        // [n_tokens]         | i   | pos
+    int32_t *       n_seq_id;   // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+    int32_t *       seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t *        output;     // [n_tokens]         | i   | -
 };
 
-// a helper for sanitizing and fulfilling a batch
+// a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
-    llama_batch_allocr();
+    llama_batch_allocr(uint32_t n_pos_per_embd);
 
     // sanitize and auto-gen missing data in the input batch
     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
             const llama_memory_i * memory,
-            bool embd_all);
+            uint32_t n_embd,
+            bool output_all);
 
     const llama_batch & get_batch() const;
 
+    uint32_t get_n_tokens() const;
     uint32_t get_n_outputs() const;
 
+    // the array of output indices in the order they were encountered during the ubatch splitting
+    std::vector<int32_t> & get_out_ids();
+
+    // min/max positions of each sequence in the current ubatch
     llama_pos seq_pos_min(llama_seq_id seq_id) const;
     llama_pos seq_pos_max(llama_seq_id seq_id) const;
 
+    // call once before splitting the batch to reset the internal state
+    void split_reset();
+
+    // simple split, unknown number of sequence sets of unequal lengths
+    llama_ubatch split_simple(uint32_t n_ubatch);
+
+    // make ubatches of equal-length sequences sets
+    llama_ubatch split_equal(uint32_t n_ubatch);
+
+    // sequence-set-wise split - each ubatch contains a single sequence-set
+    llama_ubatch split_seq(uint32_t n_ubatch);
+
+    // a helper method for creating a well-defined ubatch of tokens
+    // TODO: support embeddings if needed in the future
+    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+
 private:
     void clear();
 
+    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+    // for debugging, start with LLAMA_BATCH_DEBUG=2
+    void ubatch_print(const llama_ubatch & ubatch, int debug);
+
     llama_batch batch;
 
+    // only for debugging purposes
+    const llama_vocab * vocab;
+
+    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+    const uint32_t n_pos_per_embd;
+
+    uint32_t n_embd;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
 
     std::vector<llama_pos> pos;
     std::vector<int32_t> n_seq_id;
     std::vector<llama_seq_id *> seq_id;
+    std::vector<llama_seq_id> seq_id_unq;
+    std::vector<int32_t> seq_idx;
     std::vector<int8_t> output;
 
-    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<std::vector<bool>> seq_cpl;   // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+    using pos_set_t = std::set<llama_pos>;
+    using seq_cpl_t = std::vector<bool>;
+
+    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
+    using idx_vec_t = std::vector<int32_t>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+    // batch indices of the output
+    std::vector<int32_t> out_ids;
+
+    // used[i] indicates if token i has already been used in a previous ubatch
+    std::vector<bool> used;
+
+    // llama_ubatch points to this data:
+    struct ubatch {
+        std::vector<llama_token> token;
+        std::vector<float> embd;
+        std::vector<llama_pos> pos;
+        std::vector<int32_t> n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id> seq_id_unq;
+        std::vector<int32_t> seq_idx;
+        std::vector<int8_t> output;
+    };
+
+    // current splitting state:
+    std::vector<ubatch> ubatches;
 
     int debug;
 };
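
For orientation, here is a minimal usage sketch of the new allocator API, assembled only from the declarations above. It is not code from this commit: the variables batch, vocab, memory, n_embd and n_ubatch are assumed to exist in the caller, and split_simple() is just one of the three split strategies.

    // hedged sketch based on the header above, not part of the diff
    llama_batch_allocr balloc(/*n_pos_per_embd =*/ 1);

    if (!balloc.init(batch, vocab, memory, n_embd, /*output_all =*/ false)) {
        return; // the input batch failed sanitation
    }

    balloc.split_reset(); // "call once before splitting the batch to reset the internal state"

    while (true) {
        llama_ubatch ubatch = balloc.split_simple(n_ubatch); // or split_equal() / split_seq()
        if (ubatch.n_tokens == 0) {
            break; // the whole batch has been consumed
        }

        // consume ubatch.token / ubatch.pos / ubatch.seq_id / ubatch.output here
    }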

src/llama-context.cpp

Lines changed: 30 additions & 43 deletions
@@ -20,7 +20,7 @@ llama_context::llama_context(
         const llama_model & model,
         llama_context_params params) :
     model(model),
-    batch_allocr(std::make_unique<llama_batch_allocr>()) {
+    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
     t_start_us = model.t_start_us;
@@ -722,22 +722,26 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
 }
 
 int llama_context::encode(const llama_batch & batch_inp) {
+    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
     if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    const auto & hparams = model.hparams;
+
+    const int64_t n_embd = hparams.n_embd;
+
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
-    const llama_batch & batch = batch_allocr->get_batch();
+    const uint32_t n_tokens = balloc->get_n_tokens();
 
-    const uint32_t n_tokens = batch.n_tokens;
-
-    const int64_t n_embd = hparams.n_embd;
+    const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
@@ -751,14 +755,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_queued_tokens += n_tokens;
 
-    const auto & hparams = model.hparams;
-
-    const int64_t n_embd = hparams.n_embd;
-
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
-
-    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
     // reserve output buffer
     if (output_reserve(n_tokens) < n_tokens) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
@@ -822,12 +818,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
                     const llama_seq_id seq_id = ubatch.seq_id_unq[s];
                     const int32_t seq_idx = ubatch.seq_idx[seq_id];
 
-                    // TODO: fix indexing [UBATCH_IDX]
-                    for (uint32_t i = 0; i < n_tokens; i++) {
-                        const llama_seq_id seq_id = ubatch.seq_id[i][0];
-                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                            continue;
-                        }
                     embd_seq_out[seq_id].resize(n_embd);
                     ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                 }
@@ -839,12 +829,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
                 const uint32_t n_cls_out = hparams.n_cls_out;
 
-                // TODO: fix indexing [UBATCH_IDX]
-                for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                    const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                    if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                        continue;
-                    }
+                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                    const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                    const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
                     embd_seq_out[seq_id].resize(n_cls_out);
                     ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                 }
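
The two loops above are where the new seq_id_unq / seq_idx fields come into play: seq_id_unq[s] enumerates the unique sequence ids present in the ubatch, and seq_idx[seq_id] maps a sequence id back to its slot s, which is also its row in the pooled output tensor. A small illustration with made-up sequence ids (not taken from the commit):

    // sketch only: a ubatch whose tokens belong to sequences 0, 3 and 7
    // ubatch.n_seqs_unq == 3
    // ubatch.seq_id_unq == { 0, 3, 7 }
    // ubatch.seq_idx[0] == 0
    // ubatch.seq_idx[3] == 1
    // ubatch.seq_idx[7] == 2   (entries for absent sequence ids are presumably left unset)
    //
    // the pooled result of sequence seq_id therefore lives at row ubatch.seq_idx[seq_id],
    // which is where the (n_embd*seq_idx) and (n_cls_out*seq_idx) offsets above come from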
@@ -877,8 +865,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
         cross.seq_ids_enc.resize(n_tokens);
         for (uint32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();
+
             for (int s = 0; s < batch.n_seq_id[i]; s++) {
-                llama_seq_id seq_id = batch.seq_id[i][s];
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
                 cross.seq_ids_enc[i].insert(seq_id);
             }
         }
@@ -888,6 +878,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 }
 
 int llama_context::decode(const llama_batch & batch_inp) {
+    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(batch_inp);
@@ -899,29 +891,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    // when computing embeddings, all tokens are output
-    const bool embd_all = cparams.embeddings;
-
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
-        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
-        return -1;
-    }
-
-    const llama_batch & batch = batch_allocr->get_batch();
-
     const auto & vocab = model.vocab;
     const auto & hparams = model.hparams;
 
     const int32_t n_vocab = vocab.n_tokens();
     const int64_t n_embd = hparams.n_embd;
 
-    const uint32_t n_tokens_all = batch.n_tokens;
+    // when computing embeddings, all tokens are output
+    const bool output_all = cparams.embeddings;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
+    }
 
-    const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
+    const uint32_t n_tokens_all = balloc->get_n_tokens();
+    const uint32_t n_outputs_all = balloc->get_n_outputs();
 
-    if (embd_all) {
+    if (output_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -950,7 +937,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_context_ptr mctx;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
+        mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
         if (!mstate) {
             return -2;
         }
@@ -2056,7 +2043,7 @@ void llama_context::opt_epoch_iter(
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
+        auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
            break;
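
These two call sites also show the interface change on the memory side: llama_memory_i::init_batch() no longer receives a flattened llama_batch, it receives the batch allocator itself, so each memory implementation can drive the splitting on its own. A rough sketch of what that consumption might look like inside a memory module; the surrounding function and the choice of split_equal() are assumptions, not shown in this diff:

    // hedged sketch of a memory module consuming the allocator passed to init_batch()
    balloc.split_reset();

    std::vector<llama_ubatch> ubatches;
    while (true) {
        llama_ubatch ubatch = balloc.split_equal(n_ubatch); // or split_simple() / split_seq()
        if (ubatch.n_tokens == 0) {
            break; // nothing left to split
        }
        ubatches.push_back(std::move(ubatch));
    }
    // ... then find cache slots for each ubatch and build the returned memory state ...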

src/llama-context.h

Lines changed: 1 addition & 1 deletion
@@ -247,7 +247,7 @@ struct llama_context {
     std::map<llama_seq_id, std::vector<float>> embd_seq;
 
     // reuse the batch_allocr to avoid unnecessary memory allocations
-    std::unique_ptr<llama_batch_allocr> batch_allocr;
+    std::unique_ptr<llama_batch_allocr> balloc;
 
     uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

0 commit comments