Skip to content

Commit 4a594d7

Browse files
committed
batch : add n_used count
ggml-ci
1 parent 4acfc25 commit 4a594d7

File tree

6 files changed

+34
-1
lines changed

6 files changed

+34
-1
lines changed

src/llama-batch.cpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -407,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
407407
return n_outputs;
408408
}
409409

410+
uint32_t llama_batch_allocr::get_n_used() const {
411+
return n_used;
412+
}
413+
410414
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
411415
return out_ids;
412416
}
@@ -422,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
422426
void llama_batch_allocr::split_reset() {
423427
out_ids.clear();
424428

429+
n_used = 0;
430+
425431
used.clear();
426432
used.resize(get_n_tokens(), false);
427433

@@ -446,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
446452
idxs.push_back(cur_idx);
447453

448454
used[cur_idx] = true;
455+
++n_used;
449456

450457
++cur_idx;
451458

@@ -546,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential)
546553
idxs_per_seq[s].push_back(idx);
547554

548555
used[idx] = true;
556+
++n_used;
549557

550558
++cur_idx[s];
551559
}
@@ -587,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
587595
idxs.push_back(cur_idx);
588596

589597
used[cur_idx] = true;
598+
++n_used;
590599

591600
if (idxs.size() >= n_ubatch) {
592601
break;

src/llama-batch.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@ class llama_batch_allocr {
5454

5555
uint32_t get_n_tokens() const;
5656
uint32_t get_n_outputs() const;
57+
uint32_t get_n_used() const;
5758

5859
// the array of output indices in the order they were encountered during the ubatch splitting
5960
std::vector<int32_t> & get_out_ids();
@@ -129,6 +130,8 @@ class llama_batch_allocr {
129130
// batch indices of the output
130131
std::vector<int32_t> out_ids;
131132

133+
uint32_t n_used;
134+
132135
// used[i] indicates if token i has already been used in a previous ubatch
133136
std::vector<bool> used;
134137

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
113113
ubatches.push_back(std::move(ubatch)); // NOLINT
114114
}
115115

116+
if (balloc.get_n_used() < balloc.get_n_tokens()) {
117+
// failed to find a suitable split
118+
break;
119+
}
120+
116121
auto sinfos_base = kv_base->prepare(ubatches);
117122
if (sinfos_base.empty()) {
118123
break;
@@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
144149
ubatches.push_back(std::move(ubatch)); // NOLINT
145150
}
146151

152+
if (balloc.get_n_used() < balloc.get_n_tokens()) {
153+
// failed to find a suitable split
154+
break;
155+
}
156+
147157
auto sinfos_base = kv_base->prepare(ubatches);
148158
if (sinfos_base.empty()) {
149159
break;

src/llama-kv-cache-unified.cpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
360360
ubatches.push_back(std::move(ubatch)); // NOLINT
361361
}
362362

363+
if (balloc.get_n_used() < balloc.get_n_tokens()) {
364+
// failed to find a suitable split
365+
break;
366+
}
367+
363368
auto sinfos = prepare(ubatches);
364369
if (sinfos.empty()) {
365370
break;

src/llama-memory-hybrid.cpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
8080
ubatches.push_back(std::move(ubatch)); // NOLINT
8181
}
8282

83+
if (balloc.get_n_used() < balloc.get_n_tokens()) {
84+
// failed to find a suitable split
85+
break;
86+
}
87+
8388
// prepare the recurrent batches first
8489
if (!mem_recr->prepare(ubatches)) {
8590
// TODO: will the recurrent cache be in an undefined context at this point?

src/llama-memory-recurrent.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
377377
ubatch = balloc.split_equal(n_ubatch, false);
378378
}
379379

380-
if (ubatch.n_tokens == 0) {
380+
if (balloc.get_n_used() < balloc.get_n_tokens()) {
381+
// failed to find a suitable split
381382
break;
382383
}
383384

0 commit comments

Comments (0)