Skip to content

Commit 6a50f45

Browse files
committed
cont : more consistent indexing in recurrent cache
ggml-ci
1 parent 28dec76 commit 6a50f45

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

src/llama-kv-cache-recurrent.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -422,9 +422,8 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
422422
}
423423

424424
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
425-
const uint32_t n_seqs = ubatch.n_seqs;
426-
427425
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
426+
const uint32_t n_seqs = ubatch.n_seqs;
428427

429428
// if we have enough unused cells before the current head ->
430429
// better to start searching from the beginning of the cache, hoping to fill it
@@ -444,9 +443,11 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
444443

445444
// everything should fit if all seq_ids are smaller than the max
446445
for (uint32_t s = 0; s < n_seqs; ++s) {
447-
const uint32_t n_seq_id = ubatch.n_seq_id[s*n_seq_tokens];
446+
const uint32_t i = s*n_seq_tokens; // first token of sequence set s
447+
const uint32_t n_seq_id = ubatch.n_seq_id[i];
448+
448449
for (uint32_t j = 0; j < n_seq_id; ++j) {
449-
const llama_seq_id seq_id = ubatch.seq_id[s*n_seq_tokens][j];
450+
const llama_seq_id seq_id = ubatch.seq_id[i][j];
450451

451452
if (seq_id < 0 || (uint32_t) seq_id >= size) {
452453
// too big seq_id
@@ -505,7 +506,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
505506

506507
// find usable cell range
507508
for (uint32_t s = 0; s < n_seqs; ++s) {
508-
const llama_seq_id seq_id = ubatch.seq_id[s*n_seq_tokens][0];
509+
const uint32_t i = s*n_seq_tokens;
510+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
511+
509512
kv_cell & seq_meta = cells[seq_id];
510513
bool has_cell = false;
511514
if (seq_meta.tail >= 0) {
@@ -529,7 +532,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
529532
seq_meta.tail = next_empty_cell;
530533
// find next empty cell
531534
if (s + 1 < n_seqs) {
532-
for (uint32_t i = 0; i < size; ++i) {
535+
for (uint32_t j = 0; j < size; ++j) {
533536
next_empty_cell += 1;
534537
if (next_empty_cell >= size) { next_empty_cell -= size; }
535538
kv_cell & cell = cells[next_empty_cell];
@@ -543,8 +546,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
543546

544547
// gather and re-order
545548
for (uint32_t s = 0; s < n_seqs; ++s) {
549+
const uint32_t i = s*n_seq_tokens;
546550
const int32_t dst_id = s + min;
547-
const int32_t src_id = cells[ubatch.seq_id[s*n_seq_tokens][0]].tail;
551+
const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
548552
if (dst_id != src_id) {
549553
kv_cell & dst_cell = cells[dst_id];
550554
kv_cell & src_cell = cells[src_id];
@@ -554,8 +558,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
554558
std::swap(dst_cell.seq_id, src_cell.seq_id);
555559

556560
// swap tails
557-
for (uint32_t i = 0; i < size; ++i) {
558-
int32_t & tail = cells[i].tail;
561+
for (uint32_t j = 0; j < size; ++j) {
562+
int32_t & tail = cells[j].tail;
559563
if (tail == src_id) {
560564
tail = dst_id;
561565
} else if (tail == dst_id) {
@@ -567,20 +571,21 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
567571

568572
// update the pos of the used seqs
569573
for (uint32_t s = 0; s < n_seqs; ++s) {
570-
const llama_pos last_pos = ubatch.pos[s*n_seq_tokens + n_seq_tokens - 1];
574+
const uint32_t i = s*n_seq_tokens;
575+
const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
571576
const int32_t cell_id = s + min;
572577
kv_cell & cell = cells[cell_id];
573578

574579
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
575580
// What should happen when the pos backtracks or skips a value?
576581
// Clearing the state mid-batch would require special-casing which isn't done.
577582
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
578-
__func__, last_pos, cell.pos, ubatch.seq_id[s*n_seq_tokens][0], n_seq_tokens);
583+
__func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
579584
}
580585
cell.pos = last_pos;
581586
cell.seq_id.clear();
582-
for (int32_t j = 0; j < ubatch.n_seq_id[s*n_seq_tokens]; ++j) {
583-
const llama_seq_id seq_id = ubatch.seq_id[s*n_seq_tokens][j];
587+
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
588+
const llama_seq_id seq_id = ubatch.seq_id[i][j];
584589
cell.seq_id.insert(seq_id);
585590
cells[seq_id].tail = cell_id;
586591
}

0 commit comments

Comments
 (0)