Commit 2ff3354

memory : fix broken batch splits for recurrent cache
Splits producing more than one ubatch per batch for recurrent models were broken by #14512. This commit fixes that by moving the completeness check to after the ubatch split loop.

1 parent e1a7059

File tree

1 file changed: 6 additions, 2 deletions


src/llama-memory-recurrent.cpp

Lines changed: 6 additions & 2 deletions
@@ -377,14 +377,18 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
                 ubatch = balloc.split_equal(n_ubatch, false);
             }
 
-            if (balloc.get_n_used() < balloc.get_n_tokens()) {
-                // failed to find a suitable split
+            if (ubatch.n_tokens == 0) {
                 break;
             }
 
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         if (!prepare(ubatches)) {
             break;
         }
