File tree Expand file tree Collapse file tree 6 files changed +34
-1
lines changed Expand file tree Collapse file tree 6 files changed +34
-1
lines changed Original file line number Diff line number Diff line change @@ -411,6 +411,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
411
411
return n_outputs;
412
412
}
413
413
414
+ uint32_t llama_batch_allocr::get_n_used () const {
415
+ return n_used;
416
+ }
417
+
414
418
std::vector<int32_t > & llama_batch_allocr::get_out_ids () {
415
419
return out_ids;
416
420
}
@@ -426,6 +430,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
426
430
void llama_batch_allocr::split_reset () {
427
431
out_ids.clear ();
428
432
433
+ n_used = 0 ;
434
+
429
435
used.clear ();
430
436
used.resize (get_n_tokens (), false );
431
437
@@ -450,6 +456,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
450
456
idxs.push_back (cur_idx);
451
457
452
458
used[cur_idx] = true ;
459
+ ++n_used;
453
460
454
461
++cur_idx;
455
462
@@ -535,6 +542,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
535
542
idxs_per_seq[s].push_back (idx);
536
543
537
544
used[idx] = true ;
545
+ ++n_used;
538
546
539
547
++cur_idx[s];
540
548
}
@@ -576,6 +584,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
576
584
idxs.push_back (cur_idx);
577
585
578
586
used[cur_idx] = true ;
587
+ ++n_used;
579
588
580
589
if (idxs.size () >= n_ubatch) {
581
590
break ;
Original file line number Diff line number Diff line change @@ -56,6 +56,7 @@ class llama_batch_allocr {
56
56
57
57
uint32_t get_n_tokens () const ;
58
58
uint32_t get_n_outputs () const ;
59
+ uint32_t get_n_used () const ;
59
60
60
61
// the array of output indices in the order they were encountered during the ubatch splitting
61
62
std::vector<int32_t > & get_out_ids ();
@@ -127,6 +128,8 @@ class llama_batch_allocr {
127
128
// batch indices of the output
128
129
std::vector<int32_t > out_ids;
129
130
131
+ uint32_t n_used;
132
+
130
133
// used[i] indicates if token i has already been used in a previous ubatch
131
134
std::vector<bool > used;
132
135
Original file line number Diff line number Diff line change @@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
113
113
ubatches.push_back (std::move (ubatch)); // NOLINT
114
114
}
115
115
116
+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
117
+ // failed to find a suitable split
118
+ break ;
119
+ }
120
+
116
121
auto sinfos_base = kv_base->prepare (ubatches);
117
122
if (sinfos_base.empty ()) {
118
123
break ;
@@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
144
149
ubatches.push_back (std::move (ubatch)); // NOLINT
145
150
}
146
151
152
+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
153
+ // failed to find a suitable split
154
+ break ;
155
+ }
156
+
147
157
auto sinfos_base = kv_base->prepare (ubatches);
148
158
if (sinfos_base.empty ()) {
149
159
break ;
Original file line number Diff line number Diff line change @@ -353,6 +353,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
353
353
ubatches.push_back (std::move (ubatch)); // NOLINT
354
354
}
355
355
356
+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
357
+ // failed to find a suitable split
358
+ break ;
359
+ }
360
+
356
361
auto sinfos = prepare (ubatches);
357
362
if (sinfos.empty ()) {
358
363
break ;
Original file line number Diff line number Diff line change @@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
80
80
ubatches.push_back (std::move (ubatch)); // NOLINT
81
81
}
82
82
83
+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
84
+ // failed to find a suitable split
85
+ break ;
86
+ }
87
+
83
88
// prepare the recurrent batches first
84
89
if (!mem_recr->prepare (ubatches)) {
85
90
// TODO: will the recurrent cache be in an undefined context at this point?
Original file line number Diff line number Diff line change @@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
377
377
ubatch = balloc.split_equal (n_ubatch);
378
378
}
379
379
380
- if (ubatch.n_tokens == 0 ) {
380
+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
381
+ // failed to find a suitable split
381
382
break ;
382
383
}
383
384
You can’t perform that action at this time.
0 commit comments