#include "llama-impl.h"
#include "llama-cparams.h"
#include "llama-vocab.h"
+#include "llama-memory.h"

#include <cassert>
#include <cstring>
@@ -287,21 +288,27 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
llama_batch_allocr::llama_batch_allocr() {
    const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
    debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+
+    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    for (auto & cur : seq_cpl) {
+        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    }
}

-bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+bool llama_batch_allocr::init(
+        const llama_batch & batch_inp,
+        const llama_vocab & vocab,
+        const llama_memory_i * memory) {
    clear();

    batch = batch_inp;

    GGML_ASSERT(batch.n_tokens > 0);

-    if (!batch.pos) {
-        if (batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return false;
-        }
-    }
+    //
+    // validate input batch
+    //

    if (batch.token) {
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
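
The signature change above replaces the caller-supplied starting position p0 with a pointer to the memory, which init can then query for the last position stored per sequence. The following is a minimal sketch of that idea, not the llama.cpp implementation: toy_memory is a hypothetical stand-in for the real llama_memory_i from llama-memory.h, modeling only the seq_pos_max() accessor used for seeding positions.

// Sketch only: seeding missing positions from a memory-like object.
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

struct toy_memory {
    std::vector<llama_pos> pos_max; // last position stored per sequence, -1 if empty

    llama_pos seq_pos_max(llama_seq_id s) const { return pos_max[s]; }
};

int main() {
    toy_memory mem = { { 4, -1 } };   // seq 0 holds positions 0..4, seq 1 is empty
    const toy_memory * memory = &mem; // pass nullptr to start every sequence at 0

    for (llama_seq_id s = 0; s < 2; ++s) {
        const llama_pos p0 = memory ? memory->seq_pos_max(s) + 1 : 0;
        printf("seq %d starts at position %d\n", s, p0); // prints 5, then 0
    }
    return 0;
}
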
@@ -323,14 +330,9 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
        }
    }

-    if (!batch.pos) {
-        assert(p0 >= 0);
-        pos.resize(batch.n_tokens);
-        for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = p0 + i;
-        }
-        batch.pos = pos.data();
-    }
+    //
+    // auto-generate missing fields
+    //

    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
@@ -349,20 +351,69 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
        batch.seq_id = seq_id.data();
    }

+    if (!batch.pos) {
+        pos.resize(batch.n_tokens);
+
+        // initialize the starting position for each sequence based on the positions in the memory
+        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (!memory) {
+                p0[s] = 0;
+            } else {
+                p0[s] = memory->seq_pos_max(s) + 1;
+            }
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            pos[i] = p0[seq_id];
+
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                p0[batch.seq_id[i][s]] = pos[i] + 1;
+            }
+        }
+
+        batch.pos = pos.data();
+    }
+
    if (!batch.logits) {
        // by default return the output only for the last token
        output.resize(batch.n_tokens);
        output[output.size() - 1] = true;
        batch.logits = output.data();
    }

+    //
+    // compute stats
+    //
+
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        n_outputs += batch.logits[i] != 0;
    }

+    // determine coupled sequences
+    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
+
+            if (s > 0) {
+                const llama_seq_id s0 = batch.seq_id[i][0];
+                const llama_seq_id s1 = batch.seq_id[i][s];
+
+                // mark that sequence s1 is coupled to s0
+                seq_cpl[s1][s0] = true;
+
+                // note: the other way around is not necessary for now
+                // seq_cpl[s0][s1] = true;
+            }
+        }
+    }
+
    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
        LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
        LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
        LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
@@ -404,6 +455,58 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
                    batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
            }
            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+
+            LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
+            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+                if (seq_pos[s0].empty()) {
+                    continue;
+                }
+
+                std::stringstream ss;
+                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                    if (seq_cpl[s0][s1]) {
+                        ss << s1 << " ";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
+                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
+            }
+            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+        }
+    }
+
+    //
+    // consistency checks
+    //
+
+    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
+        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+            return false;
+        }
+
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+            return false;
+        }
+    }
+
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have diverged\n", __func__, s0, s1);
+                        return false;
+                    }
+                }
+            }
        }
    }

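
The "positions are not continuous" check above relies on seq_pos[s] holding an ordered set of distinct positions: a gap-free sequence spans exactly max - min + 1 positions, so a larger span means a position is missing. Here is a minimal sketch of that condition, assuming the per-sequence positions are stored in a std::set as the seq_pos_min()/seq_pos_max() accessors below suggest.

// Sketch only: the gap test behind the continuity check.
#include <cassert>
#include <set>

static bool is_continuous(const std::set<int> & pos) {
    if (pos.empty()) {
        return true; // nothing to check
    }
    // an ordered set of distinct values with no gaps spans exactly (max - min + 1) elements
    return *pos.rbegin() - *pos.begin() + 1 == (int) pos.size();
}

int main() {
    assert( is_continuous({3, 4, 5})); // positions 3..5, no gap
    assert(!is_continuous({3, 4, 6})); // position 5 is missing
    return 0;
}
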
@@ -418,6 +521,14 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
    return n_outputs;
}

+llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
+}
+
+llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
+}
+
void llama_batch_allocr::clear() {
    n_outputs = 0;

@@ -426,6 +537,14 @@ void llama_batch_allocr::clear() {
    n_seq_id.clear();
    seq_id.clear();
    output.clear();
+
+    for (auto & cur : seq_pos) {
+        cur.clear();
+    }
+
+    for (auto & cur : seq_cpl) {
+        std::fill(cur.begin(), cur.end(), false);
+    }
}

//