@@ -363,30 +363,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
363
363
}
364
364
365
365
llama_memory_context_ptr llama_memory_recurrent::init_batch (llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
366
- std::vector<llama_ubatch> ubatches;
366
+ do {
367
+ balloc.split_reset ();
367
368
368
- while (true ) {
369
- llama_ubatch ubatch;
369
+ std::vector<llama_ubatch> ubatches;
370
+ while (true ) {
371
+ llama_ubatch ubatch;
370
372
371
- if (embd_all) {
372
- // if all tokens are output, split by sequence
373
- ubatch = balloc.split_seq (n_ubatch);
374
- } else {
375
- ubatch = balloc.split_equal (n_ubatch);
373
+ if (embd_all) {
374
+ // if all tokens are output, split by sequence
375
+ ubatch = balloc.split_seq (n_ubatch);
376
+ } else {
377
+ ubatch = balloc.split_equal (n_ubatch);
378
+ }
379
+
380
+ if (ubatch.n_tokens == 0 ) {
381
+ break ;
382
+ }
383
+
384
+ ubatches.push_back (std::move (ubatch)); // NOLINT
376
385
}
377
386
378
- if (ubatch. n_tokens == 0 ) {
387
+ if (! prepare (ubatches) ) {
379
388
break ;
380
389
}
381
390
382
- ubatches.push_back (std::move (ubatch)); // NOLINT
383
- }
384
-
385
- if (!prepare (ubatches)) {
386
- return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
387
- }
391
+ return std::make_unique<llama_memory_recurrent_context>(this , std::move (ubatches));
392
+ } while (false );
388
393
389
- return std::make_unique<llama_memory_recurrent_context>(this , std::move (ubatches) );
394
+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE );
390
395
}
391
396
392
397
llama_memory_context_ptr llama_memory_recurrent::init_full () {
0 commit comments