
Commit 5eb1a88

batch : optional requirement for sequential sequence ids
ggml-ci
1 parent 6663128 commit 5eb1a88

File tree

6 files changed (+11 / -8 lines):

src/llama-batch.cpp
src/llama-batch.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified.cpp
src/llama-memory-hybrid.cpp
src/llama-memory-recurrent.cpp

src/llama-batch.cpp

Lines changed: 4 additions & 2 deletions
@@ -457,7 +457,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         return ubatch_add(idxs, idxs.size(), false);
     }
 
-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
     std::vector<seq_set_t> cur_seq_set;
 
     llama_seq_id last_seq_id = -1;
@@ -479,7 +479,9 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
         }
 
         // accept only increasing sequence ids
-        add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+        if (sequential) {
+            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+        }
 
         if (add) {
             cur_seq_set.push_back(seq_set[i]);
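
The effect of the new gate in isolation, as a minimal runnable sketch (simplified types; accept and seq_ids are hypothetical stand-ins for the allocator state and batch.seq_id[i][0], not llama.cpp API):

#include <cstdio>
#include <vector>

// Hypothetical reduction of the split_equal gate: with sequential == true,
// a token is accepted only while its primary sequence id grows by exactly one.
static bool accept(bool sequential, bool ubatch_empty, int seq_id, int last_seq_id) {
    bool add = true; // all other acceptance conditions elided
    if (sequential) {
        add = add && (ubatch_empty || seq_id == last_seq_id + 1);
    }
    return add;
}

int main() {
    const std::vector<int> seq_ids = {0, 1, 2, 5}; // primary seq id per candidate token
    int  last  = -1;
    bool empty = true;
    for (int id : seq_ids) {
        const bool ok = accept(/*sequential=*/true, empty, id, last);
        std::printf("seq_id %d -> %s\n", id, ok ? "accepted" : "rejected");
        if (ok) { last = id; empty = false; }
    }
    // output: 0, 1, 2 accepted; 5 rejected (gap after 2)
    return 0;
}

With sequential == false the gate is skipped entirely and seq_id 5 would be accepted as well, which is the behavior the hybrid and recurrent call sites below opt into.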

src/llama-batch.h

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ class llama_batch_allocr {
     llama_ubatch split_simple(uint32_t n_ubatch);
 
     // make ubatches of equal-length sequences sets
-    llama_ubatch split_equal(uint32_t n_ubatch);
+    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
 
     // sequence-set-wise split - each ubatch contains a single sequence-set
     llama_ubatch split_seq(uint32_t n_ubatch);
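
For callers, the pattern is unchanged apart from the new flag: drain the allocator until split_equal returns an empty ubatch. A schematic of the loop (mirrors the call sites in the files below; balloc, n_ubatch, and llama_ubatch are the surrounding llama.cpp internals, assumed in scope):

std::vector<llama_ubatch> ubatches;
while (true) {
    // request equal-length splits; pass true to additionally require
    // increasing, gap-free sequence ids within each ubatch
    auto ubatch = balloc.split_equal(n_ubatch, /*sequential=*/true);

    if (ubatch.n_tokens == 0) {
        break; // allocator exhausted
    }

    ubatches.push_back(std::move(ubatch));
}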

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 2 additions & 2 deletions
@@ -102,7 +102,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
     // first try simple split
     do {
         if (n_seq_virt > 1) {
-            // requires equal splits
+            // requires equal splits, so we skip the simple split
             break;
         }
 
@@ -141,7 +141,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
     std::vector<llama_ubatch> ubatches;
     while (true) {
-        auto ubatch = balloc.split_equal(n_ubatch);
+        auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
 
         if (ubatch.n_tokens == 0) {
             break;
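
Schematically, the iswa path first attempts the cheap simple split inside a do-block and bails out with break when n_seq_virt > 1; only then does it fall through to the equal split, where the same condition now selects the sequential requirement. A fragment reconstructed from the hunk context above, not a complete listing of init_batch:

// first try simple split
do {
    if (n_seq_virt > 1) {
        // requires equal splits, so we skip the simple split
        break;
    }
    // ... split_simple(...) attempt elided ...
} while (false);

// fallback: equal split; require sequential ids only for virtual sequences
while (true) {
    auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
    if (ubatch.n_tokens == 0) {
        break;
    }
    // ... collect ubatch, elided ...
}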

src/llama-kv-cache-unified.cpp

Lines changed: 1 addition & 1 deletion
@@ -418,7 +418,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
 
     std::vector<llama_ubatch> ubatches;
     while (true) {
-        auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch);
+        auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
 
         if (ubatch.n_tokens == 0) {
             break;

src/llama-memory-hybrid.cpp

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
             // if all tokens are output, split by sequence
             ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch, false);
         }
 
         if (ubatch.n_tokens == 0) {

src/llama-memory-recurrent.cpp

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
             // if all tokens are output, split by sequence
             ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch, false);
        }
 
         if (ubatch.n_tokens == 0) {
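
Taken together, the commit turns what used to be an unconditional check inside split_equal into an opt-in requirement. The call-site mapping after this change, summarized from the diffs above:

// split_equal(n_ubatch, sequential) call sites after this commit:
//   llama-kv-cache-unified.cpp       -> true             (n_seq_virt > 1 branch)
//   llama-kv-cache-unified-iswa.cpp  -> n_seq_virt > 1
//   llama-memory-hybrid.cpp          -> false
//   llama-memory-recurrent.cpp       -> false
// Only the unified KV-cache paths keep the increasing-sequence-id
// requirement; the hybrid and recurrent memories no longer impose it.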
