Commit 3724d37

ggerganov authored and Minh141120 committed
batch : remove logits_all flag (ggml-org#14141)
ggml-ci
1 parent 16faf81 commit 3724d37

10 files changed: +111 −252 lines
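With `logits_all` removed, the only way to request outputs is the per-token `logits` array of `llama_batch` (nonzero marks a token whose logits/embeddings are returned). A minimal sketch of the caller side, assuming only the standard `llama.h` batch fields (the helper `fill_batch` is hypothetical, not part of this commit):

```cpp
#include "llama.h"

// Fill a llama_batch with one sequence of tokens and request logits
// only for the last token -- the per-token replacement for logits_all.
static void fill_batch(llama_batch & batch, const llama_token * tokens,
                       int32_t n_tokens, llama_seq_id seq) {
    batch.n_tokens = n_tokens;
    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = seq;
        batch.logits  [i]    = (i == n_tokens - 1); // per-token output flag
    }
}
```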

src/llama-batch.cpp

Lines changed: 48 additions & 162 deletions
@@ -130,42 +130,20 @@ bool llama_batch_allocr::init(
                 warn = true;
             }
         }
-
-        if (warn) {
-            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
-
-            output.resize(batch.n_tokens, true);
-            batch.logits = output.data();
-        }
-    }
-
-    //
-    // compute stats
-    //
-
-    this->n_embd = n_embd;
-
-    // count the outputs in this batch
-    for (int32_t i = 0; i < batch.n_tokens; ++i) {
-        n_outputs += batch.logits[i] != 0;
     }
-
-    // determine coupled sequences
-    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
-    for (int32_t i = 0; i < batch.n_tokens; ++i) {
-        const llama_seq_id s0 = batch.seq_id[i][0];
-
-        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-            const llama_seq_id s1 = batch.seq_id[i][s];
-
-            seq_pos[s1].insert(batch.pos[i]);
-
-            if (s > 0) {
-                // mark that sequence s1 is coupled to s0
-                seq_cpl[s1][s0] = true;
-
-                // note: tracking the other way around is not necessary for now
-                //seq_cpl[s0][s1] = true;
+    if (batch->logits) {
+        if (ubatch.equal_seqs) {
+            for (size_t i = 0; i < length; ++i) {
+                size_t id = ids[seq.offset + i];
+                int8_t is_output = batch->logits[id];
+                ubatch.output[ubatch.n_tokens + i] = is_output;
+                if (is_output) { out_ids.push_back(id); }
+            }
+        } else {
+            // simple split
+            ubatch.output = batch->logits + seq.offset;
+            for (size_t i = 0; i < length; ++i) {
+                if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
             }
         }
     }
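The added block above is where the removed flag's job now happens: per-token output flags are copied from `batch->logits` into the ubatch, and the batch indices of output tokens are recorded in `out_ids`. A self-contained sketch of the same bookkeeping, with plain vectors standing in for the `llama_sbatch` members:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int8_t>  logits = {0, 0, 1, 0, 1}; // per-token output flags
    const std::vector<int64_t> ids    = {0, 1, 2, 3, 4}; // sorted batch indices

    std::vector<int8_t>  output(ids.size()); // ubatch output flags
    std::vector<int64_t> out_ids;            // batch indices of the outputs

    const size_t offset = 0, length = ids.size(); // one sequence spanning the batch
    for (size_t i = 0; i < length; ++i) {
        const size_t id = (size_t) ids[offset + i];
        output[i] = logits[id];
        if (output[i]) {
            out_ids.push_back((int64_t) id);
        }
    }

    for (const int64_t id : out_ids) {
        printf("output at batch index %lld\n", (long long) id); // prints 2 and 4
    }
}
```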
@@ -281,141 +259,49 @@ bool llama_batch_allocr::init(
         }
     }

-    if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
-                if (seq_cpl[s0][s1]) {
-                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
-                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
-                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
-                        return false;
-                    }
-                }
+llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
+    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+    if (!seq.empty()) {
+        size_t length = 0;
+        size_t n_tokens_in_ubatch = 0;
+        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
+        // smallest first, because it's easier to split this way;
+        // starting from the end to pop in constant time.
+        for (size_t i = seq.size(); i-- > 0;) {
+            llama_sbatch_seq & s = seq[i];
+            GGML_ASSERT(s.length > 0);
+            if (length == 0) {
+                length = s.length < n_ubatch ? s.length : n_ubatch;
             }
+            add_seq_to_ubatch(ubatch, s, length);
+            n_tokens_in_ubatch += length;
+            // shared prompts can't be mixed with any of their sequences,
+            // so it's safer to compute them in their own ubatch
+            if (s.n_seq_id > 1) { break; }
+            // stop when there isn't enough space for another sequence
+            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
         }
     }
-
-    // disallow partial sequence sub-sets:
-    //
-    // invalid:          x
-    //            i: 0 1 2 ...
-    // ---------------------------------------
-    // seq_id[i][0]: 0 0 1
-    // seq_id[i][1]: 1 1 2
-    // seq_id[i][2]: 2
-    //
-    // disallow decreasing sequence positions:
-    //
-    // invalid:                  x
-    //            i: 0 1 2 3 4 5 6 ...
-    // ---------------------------------------
-    //       pos[i]: 4 5 0 1 6 2 3
-    // seq_id[i][0]: 0 0 1 1 0 1 0
-    //
-    {
-        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-            cur_seq_set[s].set();
-        }
-
-        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-            cur_seq_pos[s] = -1;
-        }
-
-        for (int32_t i = 0; i < batch.n_tokens; ++i) {
-            const llama_pos pos = batch.pos[i];
-
-            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = batch.seq_id[i][s];
-
-                cur_seq_set[seq_id] &= seq_set[i];
-
-                if (cur_seq_set[seq_id].none()) {
-                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
-                    return false;
-                }
-
-                if (pos < cur_seq_pos[seq_id]) {
-                    LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
-                    return false;
-                }
-            }
-        }
-    }
-
-    split_reset();
-
-    return true;
+    return ubatch;
 }

-llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
-    const uint32_t n_tokens = n_seq_tokens*n_seqs;
-
-    clear();
-    split_reset();
-
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
-
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .clear();
-    ubatch.pos       .resize(n_tokens);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
-
-    for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
+    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+    if (!seq.empty()) {
+        llama_sbatch_seq & s = seq[seq.size() - 1];
+        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
+        add_seq_to_ubatch(ubatch, s, length);
     }
-
-    llama_ubatch res {
-        /*.equal_seqs   =*/ true,
-        /*.n_tokens     =*/ n_tokens,
-        /*.n_seq_tokens =*/ n_seq_tokens,
-        /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ n_seqs,
-
-        /*.token        =*/ ubatch.token.data(),
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
-    };
-
-    return res;
+    return ubatch;
 }

-const llama_batch & llama_batch_allocr::get_batch() const {
-    return batch;
-}
-
-uint32_t llama_batch_allocr::get_n_tokens() const {
-    return batch.n_tokens;
-}
-
-uint32_t llama_batch_allocr::get_n_outputs() const {
-    return n_outputs;
-}
-
-std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
-    return out_ids;
-}
-
-llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
-    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
-}
-
-llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
-    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
-}
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
+    GGML_ASSERT(batch.n_tokens >= 0);
+    this->batch = &batch;
+    this->n_embd = n_embd;

 void llama_batch_allocr::split_reset() {
     out_ids.clear();
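To make the `split_equal` loop above concrete: it walks the sequences from shortest to longest (they are popped from the back), fixes `length` from the first (shortest) one, and packs whole sequences of that equal length until the ubatch is full. A toy trace, under the assumption that `seq` is sorted by descending length as the comments in the diff state:

```cpp
#include <cstdio>
#include <vector>

int main() {
    const std::vector<size_t> seq_lengths = {8, 4, 4, 2}; // smallest last, popped first
    const size_t n_ubatch = 6;

    size_t length = 0;
    size_t n_tokens_in_ubatch = 0;
    for (size_t i = seq_lengths.size(); i-- > 0;) {
        if (length == 0) {
            length = seq_lengths[i] < n_ubatch ? seq_lengths[i] : n_ubatch;
        }
        n_tokens_in_ubatch += length;
        printf("take %zu tokens from sequence %zu\n", length, i);
        // stop when there isn't enough space for another sequence
        if (length + n_tokens_in_ubatch > n_ubatch) {
            break;
        }
    }
    // with n_ubatch = 6 this packs 2 tokens from each of sequences 3, 2, 1
}
```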

src/llama-batch.h

Lines changed: 38 additions & 72 deletions
@@ -36,94 +36,60 @@ struct llama_ubatch {
     int8_t * output; // [n_tokens] | i | -
 };

-// a helper for sanitizing, fulfilling and splitting a batch
-class llama_batch_allocr {
-public:
-    llama_batch_allocr(uint32_t n_pos_per_embd);
+struct llama_sbatch_seq {
+    int32_t n_seq_id;

-    // sanitize and auto-gen missing data in the input batch
-    // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
-    bool init(
-            const llama_batch & batch_inp,
-            const llama_vocab & vocab,
-            const llama_memory_i * memory,
-            uint32_t n_embd,
-            bool output_all);
+    llama_seq_id * seq_id;

-    const llama_batch & get_batch() const;
-
-    uint32_t get_n_tokens() const;
-    uint32_t get_n_outputs() const;
-
-    // the array of output indices in the order they were encountered during the ubatch splitting
-    std::vector<int32_t> & get_out_ids();
-
-    // min/max positions of each sequence in the current ubatch
-    llama_pos seq_pos_min(llama_seq_id seq_id) const;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const;
-
-    // call once before splitting the batch to reset the internal state
-    void split_reset();
-
-    // simple split, unknown number of sequence sets of unequal lengths
-    llama_ubatch split_simple(uint32_t n_ubatch);
-
-    // make ubatches of equal-length sequences sets
-    llama_ubatch split_equal(uint32_t n_ubatch);
-
-    // sequence-set-wise split - each ubatch contains a single sequence-set
-    llama_ubatch split_seq(uint32_t n_ubatch);
-
-    // a helper method for creating a well-defined ubatch of tokens
-    // TODO: support embeddings if needed in the future
-    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
-
-private:
-    void clear();
-
-    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
-    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
-    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
-
-    // for debugging, start with LLAMA_BATCH_DEBUG=2
-    void ubatch_print(const llama_ubatch & ubatch, int debug);
+    size_t offset;
+    size_t length;
+};

-    llama_batch batch;
+// sequence-length-aware batch splitting
+struct llama_sbatch {
+    // tokens left in this batch
+    size_t n_tokens;

     // only for debugging purposes
     const llama_vocab * vocab;

-    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
-    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-    const uint32_t n_pos_per_embd;
-
-    uint32_t n_embd;
-    uint32_t n_outputs;
+    // sorted indices into the batch
+    std::vector<int64_t> ids;
+    // batch indices of the output
+    std::vector<int64_t> out_ids;
+    std::vector<llama_sbatch_seq> seq;

     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id

-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id *> seq_id;
-    std::vector<llama_seq_id>   seq_id_unq;
-    std::vector<int32_t>        seq_idx;
-    std::vector<int8_t>         output;
+    // buffers for the ubatches
+    // TODO: very hacky, this needs a complete rework
+    struct ubatch_data {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<int8_t>         output;
+    };
+
+    std::vector<ubatch_data> udatas;

-    using pos_set_t = std::set<llama_pos>;
-    using seq_cpl_t = std::vector<bool>;
+    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);

-    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);

-    using idx_vec_t = std::vector<int32_t>;
-    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+    // simple split, unknown number of sequences of unequal lengths
+    llama_ubatch split_simple(size_t n_ubatch);

-    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+    // make batches of equal-length sequences
+    llama_ubatch split_equal(size_t n_ubatch);

-    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+    // sequence-wise split
+    llama_ubatch split_seq(size_t n_ubatch);

-    // batch indices of the output
-    std::vector<int32_t> out_ids;
+    llama_sbatch() = default;
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+};

     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
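Taken together, the header describes the intended call pattern: build an `llama_sbatch` over a full `llama_batch`, then repeatedly split off ubatches until `n_tokens` (the tokens left) reaches zero. A sketch of such a driver loop, assuming `split_simple` consumes tokens from the sbatch as the `n_tokens` comment suggests (the function `decode_in_ubatches` is hypothetical; `llama_decode` does the equivalent internally):

```cpp
#include "llama-batch.h"

// Drain a batch into ubatches of at most n_ubatch tokens each.
static void decode_in_ubatches(const llama_batch & batch, size_t n_embd, size_t n_ubatch) {
    llama_sbatch sbatch(batch, n_embd, /* simple_split */ true);

    while (sbatch.n_tokens > 0) {
        llama_ubatch ubatch = sbatch.split_simple(n_ubatch);
        // ... evaluate ubatch.n_tokens tokens with the model here ...
        (void) ubatch;
    }
}
```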
