 
 #include "llama.h"
 
-#include "llama-cparams.h"
-
 #include <array>
 #include <vector>
-#include <set>
-#include <bitset>
-#include <unordered_map>
 
-// keep this struct lightweight
-// it points to data in `llama_batch_allocr`
+// very similar to llama_batch,
+// but has more metadata about sequences
 struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence set
-    uint32_t n_seqs;       // sequence sets in the ubatch
-    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
-
-    // seq_id_unq: unique sequence ids in the ubatch
-    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
-    //             used for extracting sequence pooled embeddings
-
-    //                          // size               | idx | val
-    llama_token  *  token;      // [n_tokens]         | i   | id, token
-    float        *  embd;       // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;        // [n_tokens]         | i   | pos
-    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
-    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
-    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
-    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
-    int8_t       *  output;     // [n_tokens]         | i   | -
+    uint32_t n_seq_tokens; // tokens per sequence
+    uint32_t n_seqs;
+
+    llama_token  *  token;    // [n_tokens]
+    float        *  embd;     // [n_embd, n_tokens]
+    llama_pos    *  pos;      // [n_tokens]
+    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
+    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
+    int8_t       *  output;   // [n_tokens]
 };
 
-// a helper for sanitizing, fulfilling and splitting a batch
-class llama_batch_allocr {
-public:
-    llama_batch_allocr(uint32_t n_pos_per_embd);
-
-    // sanitize and auto-gen missing data in the input batch
-    // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
-    bool init(
-            const llama_batch & batch_inp,
-            const llama_vocab & vocab,
-            const llama_memory_i * memory,
-            uint32_t n_embd,
-            bool output_all);
-
-    const llama_batch & get_batch() const;
+struct llama_sbatch_seq {
+    int32_t n_seq_id;
 
-    uint32_t get_n_tokens()  const;
-    uint32_t get_n_outputs() const;
-    uint32_t get_n_used() const;
+    llama_seq_id * seq_id;
 
-    // the array of output indices in the order they were encountered during the ubatch splitting
-    std::vector<int32_t> & get_out_ids();
+    size_t offset;
+    size_t length;
+};
 
-    // min/max positions of each sequence in the current ubatch
-    llama_pos seq_pos_min(llama_seq_id seq_id) const;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const;
+// sequence-length-aware batch splitting
+struct llama_sbatch {
+    // tokens left in this batch
+    size_t n_tokens;
 
-    // call once before splitting the batch to reset the internal state
-    void split_reset();
+    size_t n_embd;
 
-    // simple split, unknown number of sequence sets of unequal lengths
-    llama_ubatch split_simple(uint32_t n_ubatch);
+    // sorted indices into the batch
+    std::vector<int64_t> ids;
+    // batch indices of the output
+    std::vector<int64_t> out_ids;
+    std::vector<llama_sbatch_seq> seq;
 
-    // make ubatches of equal-length sequences sets
-    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
-    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
+    const llama_batch * batch = nullptr;
 
-    // sequence-set-wise split - each ubatch contains a single sequence-set
-    llama_ubatch split_seq(uint32_t n_ubatch);
+    // buffers for the ubatches
+    // TODO: very hacky, this needs a complete rework
+    struct ubatch_data {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<int8_t>         output;
+    };
 
-    // a helper method for creating a well-defined ubatch of tokens
-    // TODO: support embeddings if needed in the future
-    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+    std::vector<ubatch_data> udatas;
 
-private:
-    void clear();
+    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
 
-    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
-    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
-    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
 
-    // for debugging, start with LLAMA_BATCH_DEBUG=2
-    void ubatch_print(const llama_ubatch & ubatch, int debug);
+    // simple split, unknown number of sequences of unequal lengths
+    llama_ubatch split_simple(size_t n_ubatch);
 
-    llama_batch batch;
+    // make batches of equal-length sequences
+    llama_ubatch split_equal(size_t n_ubatch);
 
-    // only for debugging purposes
-    const llama_vocab * vocab;
+    // sequence-wise split
+    llama_ubatch split_seq(size_t n_ubatch);
 
-    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
-    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-    const uint32_t n_pos_per_embd;
+    llama_sbatch() = default;
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+};
 
-    uint32_t n_embd;
-    uint32_t n_outputs;
+// temporary allocate memory for the input batch if needed
+struct llama_batch_allocr {
+    struct llama_batch batch;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
-
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
-    std::vector<llama_seq_id>   seq_id_unq;
-    std::vector<int32_t>        seq_idx;
-    std::vector<int8_t>         output;
-
-    using pos_set_t = std::set<llama_pos>;
-    using seq_cpl_t = std::vector<bool>;
-
-    // helper flag to quickly determine if there are any coupled sequences in the batch
-    bool has_cpl;
-
-    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
-
-    using idx_vec_t = std::vector<int32_t>;
-    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
-
-    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
-
-    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
-
-    // batch indices of the output
-    std::vector<int32_t> out_ids;
-
-    uint32_t n_used;
-
-    // used[i] indicates if token i has already been used in a previous ubatch
-    std::vector<bool> used;
-
-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id>   seq_id_unq;
-        std::vector<int32_t>        seq_idx;
-        std::vector<int8_t>         output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
+    std::vector<int8_t>         logits;
 
-    int debug;
+    // optionally fulfill the batch returned by llama_batch_get_one
+    llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
 };
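
For orientation, here is a minimal sketch of how a caller might drive the restored types, roughly following the shape of an internal decode loop. The `process_batch` wrapper, its parameters, and the `p0` value of 0 are illustrative assumptions; only the `llama_batch_allocr` and `llama_sbatch` declarations come from the header shown in this diff.

```cpp
#include "llama.h"
// ... plus the header shown in this diff, under whatever include path it lives at

// hypothetical helper: split a batch into ubatches of at most n_ubatch tokens
static void process_batch(struct llama_batch batch_inp, size_t n_embd, size_t n_ubatch) {
    // a batch coming from llama_batch_get_one() carries no positions or sequence ids;
    // llama_batch_allocr fills them in, starting here at position 0
    llama_batch_allocr batch_allocr(batch_inp, /*p0 =*/ 0);
    const llama_batch & batch = batch_allocr.batch;

    // simple_split = true keeps the original token order and produces
    // plain chunks rather than per-sequence groupings
    llama_sbatch sbatch(batch, n_embd, /*simple_split =*/ true);

    while (sbatch.n_tokens > 0) {
        llama_ubatch ubatch = sbatch.split_simple(n_ubatch);

        // ... evaluate ubatch.token / ubatch.pos / ubatch.seq_id here ...
    }
}
```

`split_equal` and `split_seq` slot into the same loop; judging from the comments above, they are the variants to use when a ubatch must hold equal-length sequences or a single sequence instead of an arbitrary chunk.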