Commit 4fc4b0d

cont : add comments
1 parent 6a50f45 commit 4fc4b0d

File tree:
  src/llama-batch.cpp
  src/llama-batch.h
  src/llama-context.cpp

3 files changed, +67 -17 lines changed


src/llama-batch.cpp

Lines changed: 35 additions & 6 deletions
@@ -88,6 +88,7 @@ bool llama_batch_allocr::init(
     llama_pos p0[LLAMA_MAX_SEQ];
     for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (!memory) {
+            // if no memory -> start from 0
             p0[s] = 0;
         } else {
             p0[s] = memory->seq_pos_max(s) + 1;
@@ -99,8 +100,11 @@ bool llama_batch_allocr::init(
 
         pos[i] = p0[seq_id];
 
+        // update the starting position for all sequences that are assigned to this token
         for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-            p0[batch.seq_id[i][s]] = pos[i] + 1;
+            const llama_seq_id seq_id = batch.seq_id[i][s];
+
+            p0[seq_id] = pos[i] + 1;
         }
     }
 
@@ -141,6 +145,7 @@ bool llama_batch_allocr::init(
 
     this->n_embd = n_embd;
 
+    // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
@@ -159,22 +164,23 @@ bool llama_batch_allocr::init(
                 // mark that sequence s1 is coupled to s0
                 seq_cpl[s1][s0] = true;
 
-                // note: the other way around is not necessary for now
+                // note: tracking the other way around is not necessary for now
                 //seq_cpl[s0][s1] = true;
             }
         }
     }
 
+    // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
     {
         seq_set_t seq_set_unq;
 
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             seq_set_t cur;
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                const llama_seq_id s0 = batch.seq_id[i][s];
+                const llama_seq_id seq_id = batch.seq_id[i][s];
 
-                cur.set(s0);
-                seq_set_unq.set(s0);
+                cur        .set(seq_id);
+                seq_set_unq.set(seq_id);
             }
 
             seq_set.push_back(cur);
@@ -263,6 +269,15 @@ bool llama_batch_allocr::init(
         }
     }
 
+    // disallow disjoint sequence sets:
+    //
+    // invalid:          x
+    //            i: 0 1 2 ...
+    // ---------------------------------------
+    // seq_id[i][0]: 0 0 1
+    // seq_id[i][1]: 1 1 2
+    // seq_id[i][2]: 2
+    //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
         for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
@@ -368,11 +383,13 @@ void llama_batch_allocr::split_reset() {
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
+    // find the first unused token
     uint32_t cur_idx = 0;
     while (cur_idx < used.size() && used[cur_idx]) {
         ++cur_idx;
     }
 
+    // we are done
     if (cur_idx >= used.size()) {
         return {};
     }
@@ -401,7 +418,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
     std::vector<seq_set_t> cur_seq_set;
 
-    // determine the sequence sets participating in this ubatch
+    // determine the non-overlapping sequence sets participating in this ubatch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         if (used[i]) {
             continue;
@@ -428,10 +445,12 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
 
     const uint32_t n_seqs = cur_seq_set.size();
 
+    // we are done
     if (n_seqs == 0) {
         return {};
     }
 
+    // the current batch index of each sequence set
     std::vector<int32_t> cur_idx(n_seqs, 0);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
@@ -440,9 +459,13 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
         }
     }
 
+    // the list of batch indices for each sequence set
+    // at the end we will concat these to get the final ubatch
     std::vector<idx_vec_t> idxs_per_seq(n_seqs);
 
     while (true) {
+        // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
+        // if we haven't reached n_ubatch
         bool can_expand = true;
 
         for (uint32_t s = 0; s < n_seqs; ++s) {
@@ -458,6 +481,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
 
         for (uint32_t s = 0; s < n_seqs; ++s) {
             const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
+
             idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
@@ -470,6 +494,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
         }
     }
 
+    // concat the per-sequence-set lists
     std::vector<int32_t> idxs;
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
@@ -480,15 +505,19 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
 }
 
 llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
+    // find the first unused token
    uint32_t cur_idx = 0;
    while (cur_idx < used.size() && used[cur_idx]) {
        ++cur_idx;
    }
 
+    // we are done
    if (cur_idx >= used.size()) {
        return {};
    }
 
+    // this is the starting sequence set
+    // we allow adding tokens only if their sequence set is a subset of the current sequence set
    auto cur_seq_set = seq_set[cur_idx];
 
    std::vector<int32_t> idxs;
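
Editor's note: the comments added to split_simple(), split_equal() and split_seq() all describe the same contract — reset the internal state once, then call a split_* method repeatedly until it returns an empty ubatch. Below is a minimal sketch of a caller driving this API; the consume_batch() helper is hypothetical and not the actual call site in llama.cpp.

#include "llama-batch.h"

// hypothetical driver; assumes `balloc` was already init()-ed with a sanitized batch,
// and real callers (e.g. llama_context) pick split_simple/split_equal/split_seq per use case
static void consume_batch(llama_batch_allocr & balloc, uint32_t n_ubatch) {
    balloc.split_reset(); // clear the per-token "used" flags before splitting

    while (true) {
        llama_ubatch ubatch = balloc.split_equal(n_ubatch);

        // every split_* method returns an empty ubatch once all tokens are used
        if (ubatch.n_tokens == 0) {
            break;
        }

        // ... evaluate `ubatch` here ...
    }
}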

src/llama-batch.h

Lines changed: 29 additions & 10 deletions
@@ -10,6 +10,8 @@
 #include <bitset>
 #include <unordered_map>
 
+// keep this struct lightweight
+// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
@@ -19,14 +21,19 @@ struct llama_ubatch {
     uint32_t n_seqs;     // sequence sets in the ubatch
     uint32_t n_seqs_unq; // unique sequence ids in the ubatch
 
-    llama_token  *  token;      // [n_tokens]
-    float        *  embd;       // [n_embd, n_tokens]
-    llama_pos    *  pos;        // [n_tokens]
-    int32_t      *  n_seq_id;   // [n_tokens]
-    llama_seq_id ** seq_id;     // [n_tokens]
-    llama_seq_id *  seq_id_unq; // [n_seqs_unq]
-    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]
-    int8_t       *  output;     // [n_tokens]
+    // seq_id_unq: unique sequence ids in the ubatch
+    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+    //             used for extracting sequence pooled embeddings
+
+    //                          // size               | idx | val
+    llama_token  *  token;      // [n_tokens]          | i   | id, token
+    float        *  embd;       // [n_embd, n_tokens]  | i   | embd
+    llama_pos    *  pos;        // [n_tokens]          | i   | pos
+    int32_t      *  n_seq_id;   // [n_tokens]          | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]          | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]        | s   | seq_id
+    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]     | -   | seq_idx
+    int8_t       *  output;     // [n_tokens]          | i   | -
 };
 
 // a helper for sanitizing, fulfilling and splitting a batch
@@ -48,11 +55,14 @@ class llama_batch_allocr {
     uint32_t get_n_tokens() const;
     uint32_t get_n_outputs() const;
 
+    // the array of output indices in the order they were encountered during the ubatch splitting
     std::vector<int32_t> & get_out_ids();
 
+    // min/max positions of each sequence in the current ubatch
     llama_pos seq_pos_min(llama_seq_id seq_id) const;
     llama_pos seq_pos_max(llama_seq_id seq_id) const;
 
+    // call once before splitting the batch to reset the internal state
     void split_reset();
 
     // simple split, unknown number of sequences of unequal lengths
@@ -62,15 +72,21 @@ class llama_batch_allocr {
     llama_ubatch split_equal(uint32_t n_ubatch);
 
     // sequence-wise split - each ubatch contains a single sequence
+    // TODO: fit more than one full sequence, as long as they fit in the ubatch
     llama_ubatch split_seq(uint32_t n_ubatch);
 
+    // a helper method for creating a well-defined ubatch of tokens
+    // TODO: support embeddings if needed in the future
     llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
 
private:
     void clear();
 
+    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
     llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
 
+    // for debugging, start with LLAMA_BATCH_DEBUG=2
     void ubatch_print(const llama_ubatch & ubatch, int debug);
 
     llama_batch batch;
@@ -99,15 +115,17 @@ class llama_batch_allocr {
     using idx_vec_t = std::vector<int32_t>;
     using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
 
-    std::vector<seq_set_t> seq_set;
+    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
 
-    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map;
+    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
 
     // batch indices of the output
     std::vector<int32_t> out_ids;
 
+    // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
+    // llama_ubatch points to this data:
     struct ubatch {
         std::vector<llama_token> token;
         std::vector<float>       embd;
@@ -119,6 +137,7 @@ class llama_batch_allocr {
         std::vector<int8_t> output;
     };
 
+    // current splitting state:
     std::vector<ubatch> ubatches;
 
     int debug;
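
Editor's note: the new comments on llama_ubatch spell out the indirection used for sequence-pooled embeddings — seq_id_unq lists the unique sequence ids in the ubatch, and seq_idx maps a sequence id back to its index in [0, n_seqs_unq). A rough sketch of reading one pooled vector per sequence with this mapping; the embd_seq buffer and the read_pooled() helper are illustrative assumptions, not part of this header.

#include "llama-batch.h"

// assumes `embd_seq` holds n_seqs_unq pooled vectors of size n_embd, stored
// contiguously in the order given by seq_idx (an assumption for this sketch)
static void read_pooled(const llama_ubatch & ub, const float * embd_seq, int32_t n_embd) {
    for (uint32_t s = 0; s < ub.n_seqs_unq; ++s) {
        const llama_seq_id seq_id = ub.seq_id_unq[s];   // unique sequence id in this ubatch
        const int32_t      idx    = ub.seq_idx[seq_id]; // its index in [0, n_seqs_unq)

        const float * row = embd_seq + idx * n_embd;    // pooled embedding for seq_id
        (void) row; // ... copy/consume `row` for sequence `seq_id` ...
    }
}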

src/llama-context.cpp

Lines changed: 3 additions & 1 deletion
@@ -865,8 +865,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
         cross.seq_ids_enc.resize(n_tokens);
         for (uint32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();
+
             for (int s = 0; s < batch.n_seq_id[i]; s++) {
-                llama_seq_id seq_id = batch.seq_id[i][s];
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
                 cross.seq_ids_enc[i].insert(seq_id);
             }
         }
