#include <array>
#include <vector>
#include <set>
+ #include <bitset>
+ #include <unordered_map>

// keep this struct lightweight
// it points to data in `llama_batch_allocr`
struct llama_ubatch {
@@ -19,105 +23,127 @@ struct llama_ubatch {
    uint32_t n_seqs;     // sequence sets in the ubatch
    uint32_t n_seqs_unq; // unique sequence ids in the ubatch

-     llama_token  *  token;    // [n_tokens]
-     float        *  embd;     // [n_embd, n_tokens]
-     llama_pos    *  pos;      // [n_tokens]
-     int32_t      *  n_seq_id; // [n_seqs]
-     llama_seq_id ** seq_id;   // [n_seqs]
-     int8_t       *  output;   // [n_tokens]
- };
-
- struct llama_sbatch_seq {
-     int32_t n_seq_id;
-
-     llama_seq_id * seq_id;
-
-     size_t offset;
-     size_t length;
- };
-
- // sequence-length-aware batch splitting
- struct llama_sbatch {
-     // tokens left in this batch
-     size_t n_tokens;
-
-     // only for debugging purposes
-     const llama_vocab * vocab;
-
-     // sorted indices into the batch
-     std::vector<int64_t> ids;
-     // batch indices of the output
-     std::vector<int64_t> out_ids;
-     std::vector<llama_sbatch_seq> seq;
-
-     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
-
-     // buffers for the ubatches
-     // TODO: very hacky, this needs a complete rework
-     struct ubatch_data {
-         std::vector<llama_token>    token;
-         std::vector<float>          embd;
-         std::vector<llama_pos>      pos;
-         std::vector<int32_t>        n_seq_id;
-         std::vector<llama_seq_id *> seq_id;
-         std::vector<int8_t>         output;
-     };
-
-     std::vector<ubatch_data> udatas;
-
-     llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-     void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-     // simple split, unknown number of sequences of unequal lengths
-     llama_ubatch split_simple(size_t n_ubatch);
-
-     // make batches of equal-length sequences
-     llama_ubatch split_equal(size_t n_ubatch);
-
-     // sequence-wise split
-     llama_ubatch split_seq(size_t n_ubatch);
-
-     llama_sbatch() = default;
-     llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+     // seq_id_unq: unique sequence ids in the ubatch
+     // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+     //             used for extracting sequence pooled embeddings
+
+     //                          // size               | idx | val
+     llama_token  *  token;      // [n_tokens]         | i   | id, token
+     float        *  embd;       // [n_embd, n_tokens] | i   | embd
+     llama_pos    *  pos;        // [n_tokens]         | i   | pos
+     int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
+     llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+     llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+     int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+     int8_t       *  output;     // [n_tokens]         | i   | -
};

- // a helper for sanitizing and fulfilling a batch
+ // a helper for sanitizing, fulfilling and splitting a batch
class llama_batch_allocr {
public:
-     llama_batch_allocr();
+     llama_batch_allocr(uint32_t n_pos_per_embd);

    // sanitize and auto-gen missing data in the input batch
    // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
    bool init(
            const llama_batch & batch_inp,
            const llama_vocab & vocab,
            const llama_memory_i * memory,
-             bool embd_all);
+             uint32_t n_embd,
+             bool output_all);

    const llama_batch & get_batch() const;

+     uint32_t get_n_tokens()  const;
    uint32_t get_n_outputs() const;

+     // the array of output indices in the order they were encountered during the ubatch splitting
+     std::vector<int32_t> & get_out_ids();
+
+     // min/max positions of each sequence in the current ubatch
    llama_pos seq_pos_min(llama_seq_id seq_id) const;
    llama_pos seq_pos_max(llama_seq_id seq_id) const;

+     // call once before splitting the batch to reset the internal state
+     void split_reset();
+
+     // simple split, unknown number of sequence sets of unequal lengths
+     llama_ubatch split_simple(uint32_t n_ubatch);
+
+     // make ubatches of equal-length sequence sets
+     llama_ubatch split_equal(uint32_t n_ubatch);
+
+     // sequence-set-wise split - each ubatch contains a single sequence-set
+     llama_ubatch split_seq(uint32_t n_ubatch);
+
+     // a helper method for creating a well-defined ubatch of tokens
+     // TODO: support embeddings if needed in the future
+     llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+
private:
    void clear();

+     // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+     // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+     llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+     // for debugging, start with LLAMA_BATCH_DEBUG=2
+     void ubatch_print(const llama_ubatch & ubatch, int debug);
+
    llama_batch batch;

+     // only for debugging purposes
+     const llama_vocab * vocab;
+
+     // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+     // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+     const uint32_t n_pos_per_embd;
+
+     uint32_t n_embd;
    uint32_t n_outputs;

    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id

    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
+     std::vector<llama_seq_id>   seq_id_unq;
+     std::vector<int32_t>        seq_idx;
    std::vector<int8_t>         output;

-     std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
-     std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+     using pos_set_t = std::set<llama_pos>;
+     using seq_cpl_t = std::vector<bool>;
+
+     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
+     using idx_vec_t = std::vector<int32_t>;
+     using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+     std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+     std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+     // batch indices of the output
+     std::vector<int32_t> out_ids;
+
+     // used[i] indicates if token i has already been used in a previous ubatch
+     std::vector<bool> used;
+
+     // llama_ubatch points to this data:
+     struct ubatch {
+         std::vector<llama_token>    token;
+         std::vector<float>          embd;
+         std::vector<llama_pos>      pos;
+         std::vector<int32_t>        n_seq_id;
+         std::vector<llama_seq_id *> seq_id;
+         std::vector<llama_seq_id>   seq_id_unq;
+         std::vector<int32_t>        seq_idx;
+         std::vector<int8_t>         output;
+     };
+
+     // current splitting state:
+     std::vector<ubatch> ubatches;

    int debug;
};
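Note (not part of the diff): a minimal sketch of how the reworked splitting API might be driven by a caller, based only on the declarations above. The wrapper function, its arguments, and the assumption that the split_*() methods return an ubatch with n_tokens == 0 once the batch is consumed (inferred from the ubatch_add() comment) are illustrative, not taken from this commit.

// sketch only - assumes the internal header above is available to the caller
#include "llama-batch.h"

static void process_batch_sketch(
        const llama_batch    & batch,
        const llama_vocab    & vocab,
        const llama_memory_i * memory,   // optional, may be nullptr per the init() comment
        uint32_t n_embd,
        uint32_t n_ubatch) {
    // one position per token/embd in the common case (see the n_pos_per_embd TODO above)
    llama_batch_allocr balloc(/*n_pos_per_embd =*/ 1);

    // sanitize and auto-gen missing data in the input batch
    if (!balloc.init(batch, vocab, memory, n_embd, /*output_all =*/ false)) {
        return; // invalid input batch
    }

    // reset the internal splitting state, then consume the batch ubatch by ubatch
    balloc.split_reset();

    while (true) {
        const llama_ubatch ubatch = balloc.split_simple(n_ubatch);
        if (ubatch.n_tokens == 0) {
            break; // the entire batch was consumed
        }

        // ... evaluate `ubatch` (ubatch.token, ubatch.pos, ubatch.seq_id, ...) ...
    }
}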