
Commit 99cfd5e

ggerganov authored and qnixsynapse committed

batch : remove logits_all flag (ggml-org#14141)

ggml-ci

1 parent c022829 · commit 99cfd5e

10 files changed: +671 −1219 lines
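The change itself is internal, but for context: with no batch-wide `logits_all` switch, output selection in the public API is expressed per token through the `logits` array of `llama_batch`. The sketch below is illustrative only and is not code from this commit; `make_batch`, `tokens`, and `n_tokens` are hypothetical caller-side names, while `llama_batch_init`/`llama_batch_free` are the existing helpers from llama.h.

```cpp
#include "llama.h"

// Illustrative sketch (not from this commit): mark per-token outputs via
// llama_batch::logits instead of relying on an all-outputs flag.
static llama_batch make_batch(const llama_token * tokens, int32_t n_tokens) {
    llama_batch batch = llama_batch_init(n_tokens, /*embd=*/0, /*n_seq_max=*/1);

    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i]    = (i == n_tokens - 1) ? 1 : 0; // output only for the last token
    }
    batch.n_tokens = n_tokens;

    return batch; // caller runs llama_decode() and later llama_batch_free()
}
```

Flagging only the tokens whose outputs are actually needed keeps the logits buffers proportional to the number of flagged positions rather than to the whole batch.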

src/llama-batch.cpp

Lines changed: 269 additions & 763 deletions
Large diffs are not rendered by default.

src/llama-batch.h

Lines changed: 56 additions & 118 deletions
```diff
@@ -2,153 +2,91 @@
 
 #include "llama.h"
 
-#include "llama-cparams.h"
-
 #include <array>
 #include <vector>
-#include <set>
-#include <bitset>
-#include <unordered_map>
 
-// keep this struct lightweight
-// it points to data in `llama_batch_allocr`
+// very similar to llama_batch,
+// but has more metadata about sequences
 struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence set
-    uint32_t n_seqs;       // sequence sets in the ubatch
-    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
-
-    // seq_id_unq: unique sequence ids in the ubatch
-    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
-    //             used for extracting sequence pooled embeddings
-
-    //                          size               | idx | val
-    llama_token  *  token;      // [n_tokens]         | i | id, token
-    float        *  embd;       // [n_embd, n_tokens] | i | embd
-    llama_pos    *  pos;        // [n_tokens]         | i | pos
-    int32_t      *  n_seq_id;   // [n_tokens]         | i | -
-    llama_seq_id ** seq_id;     // [n_tokens]         | s | s0, s1, seq_id
-    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s | seq_id
-    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | - | seq_idx
-    int8_t       *  output;     // [n_tokens]         | i | -
+    uint32_t n_seq_tokens; // tokens per sequence
+    uint32_t n_seqs;
+
+    llama_token  *  token;    // [n_tokens]
+    float        *  embd;     // [n_embd, n_tokens]
+    llama_pos    *  pos;      // [n_tokens]
+    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
+    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
+    int8_t       *  output;   // [n_tokens]
 };
 
-// a helper for sanitizing, fulfilling and splitting a batch
-class llama_batch_allocr {
-public:
-    llama_batch_allocr(uint32_t n_pos_per_embd);
-
-    // sanitize and auto-gen missing data in the input batch
-    // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
-    bool init(
-            const llama_batch & batch_inp,
-            const llama_vocab & vocab,
-            const llama_memory_i * memory,
-            uint32_t n_embd,
-            bool output_all);
-
-    const llama_batch & get_batch() const;
+struct llama_sbatch_seq {
+    int32_t n_seq_id;
 
-    uint32_t get_n_tokens()  const;
-    uint32_t get_n_outputs() const;
-    uint32_t get_n_used() const;
+    llama_seq_id * seq_id;
 
-    // the array of output indices in the order they were encountered during the ubatch splitting
-    std::vector<int32_t> & get_out_ids();
+    size_t offset;
+    size_t length;
+};
 
-    // min/max positions of each sequence in the current ubatch
-    llama_pos seq_pos_min(llama_seq_id seq_id) const;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const;
+// sequence-length-aware batch splitting
+struct llama_sbatch {
+    // tokens left in this batch
+    size_t n_tokens;
 
-    // call once before splitting the batch to reset the internal state
-    void split_reset();
+    size_t n_embd;
 
-    // simple split, unknown number of sequence sets of unequal lengths
-    llama_ubatch split_simple(uint32_t n_ubatch);
+    // sorted indices into the batch
+    std::vector<int64_t> ids;
+    // batch indices of the output
+    std::vector<int64_t> out_ids;
+    std::vector<llama_sbatch_seq> seq;
 
-    // make ubatches of equal-length sequences sets
-    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
-    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
+    const llama_batch * batch = nullptr;
 
-    // sequence-set-wise split - each ubatch contains a single sequence-set
-    llama_ubatch split_seq(uint32_t n_ubatch);
+    // buffers for the ubatches
+    // TODO: very hacky, this needs a complete rework
+    struct ubatch_data {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<int8_t>         output;
+    };
 
-    // a helper method for creating a well-defined ubatch of tokens
-    // TODO: support embeddings if needed in the future
-    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+    std::vector<ubatch_data> udatas;
 
-private:
-    void clear();
+    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
 
-    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
-    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
-    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
 
-    // for debugging, start with LLAMA_BATCH_DEBUG=2
-    void ubatch_print(const llama_ubatch & ubatch, int debug);
+    // simple split, unknown number of sequences of unequal lengths
+    llama_ubatch split_simple(size_t n_ubatch);
 
-    llama_batch batch;
+    // make batches of equal-length sequences
+    llama_ubatch split_equal(size_t n_ubatch);
 
-    // only for debugging purposes
-    const llama_vocab * vocab;
+    // sequence-wise split
+    llama_ubatch split_seq(size_t n_ubatch);
 
-    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
-    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-    const uint32_t n_pos_per_embd;
+    llama_sbatch() = default;
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+};
 
-    uint32_t n_embd;
-    uint32_t n_outputs;
+// temporary allocate memory for the input batch if needed
+struct llama_batch_allocr {
+    struct llama_batch batch;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
-
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
-    std::vector<llama_seq_id>   seq_id_unq;
-    std::vector<int32_t>        seq_idx;
-    std::vector<int8_t>         output;
-
-    using pos_set_t = std::set<llama_pos>;
-    using seq_cpl_t = std::vector<bool>;
-
-    // helper flag to quickly determine if there are any coupled sequences in the batch
-    bool has_cpl;
-
-    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
-
-    using idx_vec_t = std::vector<int32_t>;
-    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
-
-    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
-
-    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
-
-    // batch indices of the output
-    std::vector<int32_t> out_ids;
-
-    uint32_t n_used;
-
-    // used[i] indicates if token i has already been used in a previous ubatch
-    std::vector<bool> used;
-
-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id>   seq_id_unq;
-        std::vector<int32_t>        seq_idx;
-        std::vector<int8_t>         output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
+    std::vector<int8_t>         logits;
 
-    int debug;
+    // optionally fulfill the batch returned by llama_batch_get_one
+    llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
 };
```
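For orientation, here is a minimal sketch of how the sequence-aware splitter declared on the `+` side above is typically driven: construct an `llama_sbatch` over the full batch, then carve off ubatches until no tokens remain. The function name `process_batch` and the `n_ubatch` parameter are illustrative; only the `llama_sbatch`/`llama_ubatch` declarations come from the header shown in this diff.

```cpp
#include "llama-batch.h" // the header shown in the diff above

// Illustrative driver loop, not code from this commit.
static void process_batch(const llama_batch & batch, size_t n_embd, size_t n_ubatch) {
    // simple_split = true: tokens are taken in order, without regrouping by sequence
    llama_sbatch sbatch(batch, n_embd, /*simple_split=*/true);

    while (sbatch.n_tokens > 0) {
        // carve off up to n_ubatch tokens; split_equal()/split_seq() are the
        // sequence-aware alternatives declared in the struct above
        llama_ubatch ubatch = sbatch.split_simple(n_ubatch);

        // ... decode `ubatch`; outputs are produced for tokens where
        // ubatch.output[i] != 0 ...
    }
}
```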

0 commit comments