@@ -422,9 +422,8 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
422
422
}
423
423
424
424
bool llama_kv_cache_recurrent::find_slot (const llama_ubatch & ubatch) {
425
- const uint32_t n_seqs = ubatch.n_seqs ;
426
-
427
425
const uint32_t n_seq_tokens = ubatch.n_seq_tokens ;
426
+ const uint32_t n_seqs = ubatch.n_seqs ;
428
427
429
428
// if we have enough unused cells before the current head ->
430
429
// better to start searching from the beginning of the cache, hoping to fill it
@@ -444,9 +443,11 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
444
443
445
444
// everything should fit if all seq_ids are smaller than the max
446
445
for (uint32_t s = 0 ; s < n_seqs; ++s) {
447
- const uint32_t n_seq_id = ubatch.n_seq_id [s*n_seq_tokens];
446
+ const uint32_t i = s*n_seq_tokens; // first token of sequence set s
447
+ const uint32_t n_seq_id = ubatch.n_seq_id [i];
448
+
448
449
for (uint32_t j = 0 ; j < n_seq_id; ++j) {
449
- const llama_seq_id seq_id = ubatch.seq_id [s*n_seq_tokens ][j];
450
+ const llama_seq_id seq_id = ubatch.seq_id [i ][j];
450
451
451
452
if (seq_id < 0 || (uint32_t ) seq_id >= size) {
452
453
// too big seq_id
@@ -505,7 +506,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
505
506
506
507
// find usable cell range
507
508
for (uint32_t s = 0 ; s < n_seqs; ++s) {
508
- const llama_seq_id seq_id = ubatch.seq_id [s*n_seq_tokens][0 ];
509
+ const uint32_t i = s*n_seq_tokens;
510
+ const llama_seq_id seq_id = ubatch.seq_id [i][0 ];
511
+
509
512
kv_cell & seq_meta = cells[seq_id];
510
513
bool has_cell = false ;
511
514
if (seq_meta.tail >= 0 ) {
@@ -529,7 +532,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
529
532
seq_meta.tail = next_empty_cell;
530
533
// find next empty cell
531
534
if (s + 1 < n_seqs) {
532
- for (uint32_t i = 0 ; i < size; ++i ) {
535
+ for (uint32_t j = 0 ; j < size; ++j ) {
533
536
next_empty_cell += 1 ;
534
537
if (next_empty_cell >= size) { next_empty_cell -= size; }
535
538
kv_cell & cell = cells[next_empty_cell];
@@ -543,8 +546,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
543
546
544
547
// gather and re-order
545
548
for (uint32_t s = 0 ; s < n_seqs; ++s) {
549
+ const uint32_t i = s*n_seq_tokens;
546
550
const int32_t dst_id = s + min;
547
- const int32_t src_id = cells[ubatch.seq_id [s*n_seq_tokens ][0 ]].tail ;
551
+ const int32_t src_id = cells[ubatch.seq_id [i ][0 ]].tail ;
548
552
if (dst_id != src_id) {
549
553
kv_cell & dst_cell = cells[dst_id];
550
554
kv_cell & src_cell = cells[src_id];
@@ -554,8 +558,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
554
558
std::swap (dst_cell.seq_id , src_cell.seq_id );
555
559
556
560
// swap tails
557
- for (uint32_t i = 0 ; i < size; ++i ) {
558
- int32_t & tail = cells[i ].tail ;
561
+ for (uint32_t j = 0 ; j < size; ++j ) {
562
+ int32_t & tail = cells[j ].tail ;
559
563
if (tail == src_id) {
560
564
tail = dst_id;
561
565
} else if (tail == dst_id) {
@@ -567,20 +571,21 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
567
571
568
572
// update the pos of the used seqs
569
573
for (uint32_t s = 0 ; s < n_seqs; ++s) {
570
- const llama_pos last_pos = ubatch.pos [s*n_seq_tokens + n_seq_tokens - 1 ];
574
+ const uint32_t i = s*n_seq_tokens;
575
+ const llama_pos last_pos = ubatch.pos [i + n_seq_tokens - 1 ];
571
576
const int32_t cell_id = s + min;
572
577
kv_cell & cell = cells[cell_id];
573
578
574
579
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
575
580
// What should happen when the pos backtracks or skips a value?
576
581
// Clearing the state mid-batch would require special-casing which isn't done.
577
582
LLAMA_LOG_WARN (" %s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n " ,
578
- __func__, last_pos, cell.pos , ubatch.seq_id [s*n_seq_tokens ][0 ], n_seq_tokens);
583
+ __func__, last_pos, cell.pos , ubatch.seq_id [i ][0 ], n_seq_tokens);
579
584
}
580
585
cell.pos = last_pos;
581
586
cell.seq_id .clear ();
582
- for (int32_t j = 0 ; j < ubatch.n_seq_id [s*n_seq_tokens ]; ++j) {
583
- const llama_seq_id seq_id = ubatch.seq_id [s*n_seq_tokens ][j];
587
+ for (int32_t j = 0 ; j < ubatch.n_seq_id [i ]; ++j) {
588
+ const llama_seq_id seq_id = ubatch.seq_id [i ][j];
584
589
cell.seq_id .insert (seq_id);
585
590
cells[seq_id].tail = cell_id;
586
591
}
0 commit comments