File tree Expand file tree Collapse file tree 6 files changed +34
-1
lines changed Expand file tree Collapse file tree 6 files changed +34
-1
lines changed Original file line number Diff line number Diff line change @@ -405,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
405
405
return n_outputs;
406
406
}
407
407
408
+ uint32_t llama_batch_allocr::get_n_used () const {
409
+ return n_used;
410
+ }
411
+
408
412
std::vector<int32_t > & llama_batch_allocr::get_out_ids () {
409
413
return out_ids;
410
414
}
@@ -420,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
420
424
void llama_batch_allocr::split_reset () {
421
425
out_ids.clear ();
422
426
427
+ n_used = 0 ;
428
+
423
429
used.clear ();
424
430
used.resize (get_n_tokens (), false );
425
431
@@ -444,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
444
450
idxs.push_back (cur_idx);
445
451
446
452
used[cur_idx] = true ;
453
+ ++n_used;
447
454
448
455
++cur_idx;
449
456
@@ -529,6 +536,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
529
536
idxs_per_seq[s].push_back (idx);
530
537
531
538
used[idx] = true ;
539
+ ++n_used;
532
540
533
541
++cur_idx[s];
534
542
}
@@ -570,6 +578,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
570
578
idxs.push_back (cur_idx);
571
579
572
580
used[cur_idx] = true ;
581
+ ++n_used;
573
582
574
583
if (idxs.size () >= n_ubatch) {
575
584
break ;
Original file line number Diff line number Diff line change @@ -54,6 +54,7 @@ class llama_batch_allocr {
54
54
55
55
uint32_t get_n_tokens () const ;
56
56
uint32_t get_n_outputs () const ;
57
+ uint32_t get_n_used () const ;
57
58
58
59
// the array of output indices in the order they were encountered during the ubatch splitting
59
60
std::vector<int32_t > & get_out_ids ();
@@ -125,6 +126,8 @@ class llama_batch_allocr {
125
126
// batch indices of the output
126
127
std::vector<int32_t > out_ids;
127
128
129
+ uint32_t n_used;
130
+
128
131
// used[i] indicates if token i has already been used in a previous ubatch
129
132
std::vector<bool > used;
130
133
Original file line number Diff line number Diff line change @@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
113
113
ubatches.push_back (std::move (ubatch)); // NOLINT
114
114
}
115
115
116
+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
117
+ // failed to find a suitable split
118
+ break ;
119
+ }
120
+
116
121
auto sinfos_base = kv_base->prepare (ubatches);
117
122
if (sinfos_base.empty ()) {
118
123
break ;
@@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
144
149
ubatches.push_back (std::move (ubatch)); // NOLINT
145
150
}
146
151
152
+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
153
+ // failed to find a suitable split
154
+ break ;
155
+ }
156
+
147
157
auto sinfos_base = kv_base->prepare (ubatches);
148
158
if (sinfos_base.empty ()) {
149
159
break ;
Original file line number Diff line number Diff line change @@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
360
360
ubatches.push_back (std::move (ubatch)); // NOLINT
361
361
}
362
362
363
+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
364
+ // failed to find a suitable split
365
+ break ;
366
+ }
367
+
363
368
auto sinfos = prepare (ubatches);
364
369
if (sinfos.empty ()) {
365
370
break ;
Original file line number Diff line number Diff line change @@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
80
80
ubatches.push_back (std::move (ubatch)); // NOLINT
81
81
}
82
82
83
+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
84
+ // failed to find a suitable split
85
+ break ;
86
+ }
87
+
83
88
// prepare the recurrent batches first
84
89
if (!mem_recr->prepare (ubatches)) {
85
90
// TODO: will the recurrent cache be in an undefined context at this point?
Original file line number Diff line number Diff line change @@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
377
377
ubatch = balloc.split_equal (n_ubatch);
378
378
}
379
379
380
- if (ubatch.n_tokens == 0 ) {
380
+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
381
+ // failed to find a suitable split
381
382
break ;
382
383
}
383
384
You can’t perform that action at this time.
0 commit comments