@@ -88,6 +88,7 @@ bool llama_batch_allocr::init(
88
88
llama_pos p0[LLAMA_MAX_SEQ];
89
89
for (int32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
90
90
if (!memory) {
91
+ // if no memory -> start from 0
91
92
p0[s] = 0 ;
92
93
} else {
93
94
p0[s] = memory->seq_pos_max (s) + 1 ;
@@ -99,8 +100,11 @@ bool llama_batch_allocr::init(
99
100
100
101
pos[i] = p0[seq_id];
101
102
103
+ // update the starting position for all sequences that are assigned to the this token
102
104
for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
103
- p0[batch.seq_id [i][s]] = pos[i] + 1 ;
105
+ const llama_seq_id seq_id = batch.seq_id [i][s];
106
+
107
+ p0[seq_id] = pos[i] + 1 ;
104
108
}
105
109
}
106
110
@@ -141,6 +145,7 @@ bool llama_batch_allocr::init(
141
145
142
146
this ->n_embd = n_embd;
143
147
148
+ // count the outputs in this batch
144
149
for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
145
150
n_outputs += batch.logits [i] != 0 ;
146
151
}
@@ -159,22 +164,23 @@ bool llama_batch_allocr::init(
159
164
// mark that sequence s1 is coupled to s0
160
165
seq_cpl[s1][s0] = true ;
161
166
162
- // note: the other way around is not necessary for now
167
+ // note: tracking the other way around is not necessary for now
163
168
// seq_cpl[s0][s1] = true;
164
169
}
165
170
}
166
171
}
167
172
173
+ // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
168
174
{
169
175
seq_set_t seq_set_unq;
170
176
171
177
for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
172
178
seq_set_t cur;
173
179
for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
174
- const llama_seq_id s0 = batch.seq_id [i][s];
180
+ const llama_seq_id seq_id = batch.seq_id [i][s];
175
181
176
- cur.set (s0 );
177
- seq_set_unq.set (s0 );
182
+ cur .set (seq_id );
183
+ seq_set_unq.set (seq_id );
178
184
}
179
185
180
186
seq_set.push_back (cur);
@@ -263,6 +269,15 @@ bool llama_batch_allocr::init(
263
269
}
264
270
}
265
271
272
+ // disallow disjoint sequence sets:
273
+ //
274
+ // invalid: x
275
+ // i: 0 1 2 ...
276
+ // ---------------------------------------
277
+ // seq_id[i][0]: 0 0 1
278
+ // seq_id[i][1]: 1 1 2
279
+ // seq_id[i][2]: 2
280
+ //
266
281
{
267
282
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
268
283
for (int32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
@@ -368,11 +383,13 @@ void llama_batch_allocr::split_reset() {
368
383
}
369
384
370
385
llama_ubatch llama_batch_allocr::split_simple (uint32_t n_ubatch) {
386
+ // find the first unused token
371
387
uint32_t cur_idx = 0 ;
372
388
while (cur_idx < used.size () && used[cur_idx]) {
373
389
++cur_idx;
374
390
}
375
391
392
+ // we are done
376
393
if (cur_idx >= used.size ()) {
377
394
return {};
378
395
}
@@ -401,7 +418,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
401
418
llama_ubatch llama_batch_allocr::split_equal (uint32_t n_ubatch) {
402
419
std::vector<seq_set_t > cur_seq_set;
403
420
404
- // determine the sequence sets participating in this ubatch
421
+ // determine the non-overlapping sequence sets participating in this ubatch
405
422
for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
406
423
if (used[i]) {
407
424
continue ;
@@ -428,10 +445,12 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
428
445
429
446
const uint32_t n_seqs = cur_seq_set.size ();
430
447
448
+ // we are done
431
449
if (n_seqs == 0 ) {
432
450
return {};
433
451
}
434
452
453
+ // the current batch index of each sequence set
435
454
std::vector<int32_t > cur_idx (n_seqs, 0 );
436
455
437
456
for (uint32_t s = 0 ; s < n_seqs; ++s) {
@@ -440,9 +459,13 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
440
459
}
441
460
}
442
461
462
+ // the list of batch indices for each sequence set
463
+ // at the end we will concat these to get the final ubatch
443
464
std::vector<idx_vec_t > idxs_per_seq (n_seqs);
444
465
445
466
while (true ) {
467
+ // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
468
+ // if we haven't reached n_ubatch
446
469
bool can_expand = true ;
447
470
448
471
for (uint32_t s = 0 ; s < n_seqs; ++s) {
@@ -458,6 +481,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
458
481
459
482
for (uint32_t s = 0 ; s < n_seqs; ++s) {
460
483
const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
484
+
461
485
idxs_per_seq[s].push_back (idx);
462
486
463
487
used[idx] = true ;
@@ -470,6 +494,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
470
494
}
471
495
}
472
496
497
+ // concat the per-sequence-set lists
473
498
std::vector<int32_t > idxs;
474
499
475
500
for (uint32_t s = 0 ; s < n_seqs; ++s) {
@@ -480,15 +505,19 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
480
505
}
481
506
482
507
llama_ubatch llama_batch_allocr::split_seq (uint32_t n_ubatch) {
508
+ // find the first unused token
483
509
uint32_t cur_idx = 0 ;
484
510
while (cur_idx < used.size () && used[cur_idx]) {
485
511
++cur_idx;
486
512
}
487
513
514
+ // we are done
488
515
if (cur_idx >= used.size ()) {
489
516
return {};
490
517
}
491
518
519
+ // this is the starting sequence set
520
+ // we allow adding tokens only if their sequence set is a subset of the current sequence set
492
521
auto cur_seq_set = seq_set[cur_idx];
493
522
494
523
std::vector<int32_t > idxs;
0 commit comments