Skip to content

Commit 732d0ed

Browse files
ggerganov and Minh141120
authored and committed
batch : add n_used count (ggml-org#14512)
ggml-ci
1 parent b0fa27f commit 732d0ed

File tree

6 files changed

+34
-1
lines changed

6 files changed

+34
-1
lines changed

src/llama-batch.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
411411
return n_outputs;
412412
}
413413

414+
uint32_t llama_batch_allocr::get_n_used() const {
415+
return n_used;
416+
}
417+
414418
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
415419
return out_ids;
416420
}
@@ -426,6 +430,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
426430
void llama_batch_allocr::split_reset() {
427431
out_ids.clear();
428432

433+
n_used = 0;
434+
429435
used.clear();
430436
used.resize(get_n_tokens(), false);
431437

@@ -450,6 +456,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
450456
idxs.push_back(cur_idx);
451457

452458
used[cur_idx] = true;
459+
++n_used;
453460

454461
++cur_idx;
455462

@@ -535,6 +542,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
535542
idxs_per_seq[s].push_back(idx);
536543

537544
used[idx] = true;
545+
++n_used;
538546

539547
++cur_idx[s];
540548
}
@@ -576,6 +584,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
576584
idxs.push_back(cur_idx);
577585

578586
used[cur_idx] = true;
587+
++n_used;
579588

580589
if (idxs.size() >= n_ubatch) {
581590
break;

src/llama-batch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class llama_batch_allocr {
5656

5757
uint32_t get_n_tokens() const;
5858
uint32_t get_n_outputs() const;
59+
uint32_t get_n_used() const;
5960

6061
// the array of output indices in the order they were encountered during the ubatch splitting
6162
std::vector<int32_t> & get_out_ids();
@@ -127,6 +128,8 @@ class llama_batch_allocr {
127128
// batch indices of the output
128129
std::vector<int32_t> out_ids;
129130

131+
uint32_t n_used;
132+
130133
// used[i] indicates if token i has already been used in a previous ubatch
131134
std::vector<bool> used;
132135

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
113113
ubatches.push_back(std::move(ubatch)); // NOLINT
114114
}
115115

116+
if (balloc.get_n_used() < balloc.get_n_tokens()) {
117+
// failed to find a suitable split
118+
break;
119+
}
120+
116121
auto sinfos_base = kv_base->prepare(ubatches);
117122
if (sinfos_base.empty()) {
118123
break;
@@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
144149
ubatches.push_back(std::move(ubatch)); // NOLINT
145150
}
146151

152+
if (balloc.get_n_used() < balloc.get_n_tokens()) {
153+
// failed to find a suitable split
154+
break;
155+
}
156+
147157
auto sinfos_base = kv_base->prepare(ubatches);
148158
if (sinfos_base.empty()) {
149159
break;

src/llama-kv-cache-unified.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
353353
ubatches.push_back(std::move(ubatch)); // NOLINT
354354
}
355355

356+
if (balloc.get_n_used() < balloc.get_n_tokens()) {
357+
// failed to find a suitable split
358+
break;
359+
}
360+
356361
auto sinfos = prepare(ubatches);
357362
if (sinfos.empty()) {
358363
break;

src/llama-memory-hybrid.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
8080
ubatches.push_back(std::move(ubatch)); // NOLINT
8181
}
8282

83+
if (balloc.get_n_used() < balloc.get_n_tokens()) {
84+
// failed to find a suitable split
85+
break;
86+
}
87+
8388
// prepare the recurrent batches first
8489
if (!mem_recr->prepare(ubatches)) {
8590
// TODO: will the recurrent cache be in an undefined context at this point?

src/llama-memory-recurrent.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
377377
ubatch = balloc.split_equal(n_ubatch);
378378
}
379379

380-
if (ubatch.n_tokens == 0) {
380+
if (balloc.get_n_used() < balloc.get_n_tokens()) {
381+
// failed to find a suitable split
381382
break;
382383
}
383384

0 commit comments

Comments (0)