Commit 13df0aa

ggerganov authored and Minh141120 committed
batch : add optional for sequential equal split (ggml-org#14511)
ggml-ci
1 parent 567b16c commit 13df0aa

5 files changed (+26 / -5 lines)

src/llama-batch.cpp

Lines changed: 18 additions & 1 deletion
```diff
@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
 
                     // note: tracking the other way around is not necessary for now
                     //seq_cpl[s0][s1] = true;
+
+                    has_cpl = true;
                 }
             }
         }
@@ -472,9 +474,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
     return ubatch_add(idxs, idxs.size(), false);
 }
 
-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+    if (sequential && has_cpl) {
+        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+
+        return {};
+    }
+
     std::vector<seq_set_t> cur_seq_set;
 
+    llama_seq_id last_seq_id = -1;
+
     // determine the non-overlapping sequence sets participating in this ubatch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         if (used[i]) {
@@ -491,9 +501,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             }
         }
 
+        // accept only increasing sequence ids
+        if (sequential) {
+            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+        }
+
         if (add) {
             cur_seq_set.push_back(seq_set[i]);
 
+            last_seq_id = batch.seq_id[i][0];
+
            if (cur_seq_set.size() > n_ubatch) {
                 break;
             }
```
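The new `sequential` mode only accepts a candidate sequence set when its first sequence id directly continues the previous one (`last_seq_id + 1`). Below is a minimal, hypothetical sketch of that acceptance rule in isolation; it is not code from this commit and it simplifies away the allocator's sequence-set bookkeeping (the real check runs per candidate sequence set, not per token).

```cpp
// Illustrative sketch only: a flat list of first sequence ids stands in for
// the tokens of a batch, and repeated ids stand for tokens of the same sequence.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // first seq_id of each token in a hypothetical batch
    std::vector<int32_t> seq_ids = {0, 0, 1, 1, 3, 2};

    std::vector<int32_t> accepted;
    int32_t last_seq_id = -1; // mirrors the allocator's last_seq_id tracking

    for (int32_t id : seq_ids) {
        // accept a token if it stays on the current sequence or moves to the
        // next consecutive sequence id (last_seq_id + 1)
        const bool add = accepted.empty() || id == last_seq_id || id == last_seq_id + 1;
        if (!add) {
            break; // the run of sequential sequence ids ends here
        }

        accepted.push_back(id);
        last_seq_id = id;
    }

    // prints "accepted 4 tokens": 0, 0, 1, 1 - the jump to 3 breaks the run
    printf("accepted %zu tokens\n", accepted.size());
    return 0;
}
```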

src/llama-batch.h

Lines changed: 5 additions & 1 deletion
```diff
@@ -72,7 +72,8 @@ class llama_batch_allocr {
     llama_ubatch split_simple(uint32_t n_ubatch);
 
     // make ubatches of equal-length sequences sets
-    llama_ubatch split_equal(uint32_t n_ubatch);
+    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
 
     // sequence-set-wise split - each ubatch contains a single sequence-set
     llama_ubatch split_seq(uint32_t n_ubatch);
@@ -115,6 +116,9 @@ class llama_batch_allocr {
     using pos_set_t = std::set<llama_pos>;
     using seq_cpl_t = std::vector<bool>;
 
+    // helper flag to quickly determine if there are any coupled sequences in the batch
+    bool has_cpl;
+
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
```
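Per the comment added to the declaration, passing `true` for the new parameter asks for ubatches whose tokens carry increasing, consecutive sequence ids; all existing callers in the files below keep the previous behavior by passing `false`. A hypothetical caller sketch follows, assuming a `balloc` allocator and `n_ubatch` already in scope, and noting that, per the implementation above, a sequential split of a batch with coupled sequences logs an error and yields an empty ubatch:

```cpp
// Hypothetical usage sketch - mirrors the while-loop pattern of the memory
// modules changed below; `balloc` and `n_ubatch` are assumed to exist.
std::vector<llama_ubatch> ubatches;

while (true) {
    // true -> require increasing, consecutive sequence ids in each ubatch;
    // with coupled sequences in the batch this returns an empty ubatch
    auto ubatch = balloc.split_equal(n_ubatch, /*sequential =*/ true);

    if (ubatch.n_tokens == 0) {
        break; // batch exhausted, or the sequential split was rejected
    }

    ubatches.push_back(std::move(ubatch));
}
```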

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -140,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
     std::vector<llama_ubatch> ubatches;
     while (true) {
-        auto ubatch = balloc.split_equal(n_ubatch);
+        auto ubatch = balloc.split_equal(n_ubatch, false);
 
         if (ubatch.n_tokens == 0) {
             break;
```

src/llama-memory-hybrid.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch);
+                ubatch = balloc.split_equal(n_ubatch, false);
             }
 
             if (ubatch.n_tokens == 0) {
```

src/llama-memory-recurrent.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -374,7 +374,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
             // if all tokens are output, split by sequence
             ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch, false);
         }
 
         if (balloc.get_n_used() < balloc.get_n_tokens()) {
```
