diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 4b90dac7a327c..a1b5b1a272cc0 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -377,14 +377,18 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & ubatch = balloc.split_equal(n_ubatch, false); } - if (balloc.get_n_used() < balloc.get_n_tokens()) { - // failed to find a suitable split + if (ubatch.n_tokens == 0) { break; } ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + if (!prepare(ubatches)) { break; }