
Commit 16faf81

ggerganov authored and Minh141120 committed
kv-cache : fix split_equal handling in unified implementation (ggml-org#14130)
ggml-ci
1 parent aed2b29 commit 16faf81

3 files changed: +69 additions, -60 deletions
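
For orientation: the commit title refers to the two batch-splitting strategies used when building micro-batches (ubatches). Roughly, split_simple cuts the batch into ubatches in the original token order, while split_equal builds ubatches in which every sequence contributes the same number of tokens, giving the n_seqs x n_seq_tokens layout that the rest of this diff indexes. The snippet below is a standalone toy sketch with made-up types (toy_token, toy_ubatch), not the real llama_batch/llama_ubatch API.

```cpp
#include <algorithm>
#include <cstdio>
#include <map>
#include <vector>

// toy stand-ins for the real llama.cpp batch types
struct toy_token { int seq_id; int pos; };
using toy_ubatch = std::vector<toy_token>;

// "simple" split: chunk the batch into ubatches of up to n_ubatch tokens,
// keeping the original token order (sequences may be interleaved)
static std::vector<toy_ubatch> split_simple(const std::vector<toy_token> & batch, size_t n_ubatch) {
    std::vector<toy_ubatch> out;
    for (size_t i = 0; i < batch.size(); i += n_ubatch) {
        out.emplace_back(batch.begin() + i, batch.begin() + std::min(i + n_ubatch, batch.size()));
    }
    return out;
}

// "equal" split: every ubatch takes the same number of tokens from each sequence,
// so a ubatch can be viewed as an n_seqs x n_seq_tokens grid (equal_seqs layout);
// the toy assumes all sequences in the batch have the same length
static std::vector<toy_ubatch> split_equal(const std::vector<toy_token> & batch, size_t n_ubatch) {
    std::map<int, std::vector<toy_token>> per_seq;
    for (const auto & t : batch) {
        per_seq[t.seq_id].push_back(t);
    }

    const size_t n_seqs       = per_seq.size();
    const size_t seq_len      = per_seq.begin()->second.size();
    const size_t n_seq_tokens = std::max<size_t>(1, n_ubatch / n_seqs);

    std::vector<toy_ubatch> out;
    for (size_t off = 0; off < seq_len; off += n_seq_tokens) {
        toy_ubatch ub;
        for (const auto & kv : per_seq) {
            for (size_t j = off; j < std::min(off + n_seq_tokens, seq_len); ++j) {
                ub.push_back(kv.second[j]);
            }
        }
        out.push_back(std::move(ub));
    }
    return out;
}

int main() {
    // two sequences, 4 tokens each, interleaved in the original batch
    std::vector<toy_token> batch = {
        {0, 0}, {1, 0}, {0, 1}, {1, 1}, {0, 2}, {1, 2}, {0, 3}, {1, 3},
    };

    printf("split_simple: %zu ubatches\n", split_simple(batch, 4).size());
    printf("split_equal : %zu ubatches\n", split_equal (batch, 4).size());
    return 0;
}
```

The equal layout is what the per-sequence indexing introduced in apply_ubatch and set_input_kq_mask below relies on.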

src/llama-context.cpp

Lines changed: 2 additions & 0 deletions
@@ -862,6 +862,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const auto & batch = balloc->get_batch();
 
     // remember the sequence ids used during the encoding - needed for cross attention later
+    // TODO: the sequence indexing here is likely not correct in the general case
+    // probably works only for split_simple
     cross.seq_ids_enc.resize(n_tokens);
     for (uint32_t i = 0; i < n_tokens; i++) {
         cross.seq_ids_enc[i].clear();
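
The new TODO flags that this loop looks up sequence ids by ubatch token position, which matches the original batch order only for split_simple; an equal split regroups tokens by sequence. Below is a hedged, hypothetical sketch of the distinction, with a sketch_ubatch struct whose fields are only loosely modelled on llama_ubatch (they are illustrative, not the real API):

```cpp
// hypothetical sketch, not llama.cpp code: looking up the sequence id of a
// ubatch token under the two layouts
struct sketch_ubatch {
    int        n_tokens;       // total tokens in the ubatch
    int        n_seq_tokens;   // tokens per sequence (equal-split layout)
    int        n_seqs;         // sequences in the ubatch
    const int *seq_id_simple;  // one sequence id per token, in original batch order
    const int *seq_id_equal;   // one sequence id per sequence slot
};

// simple split: ubatch token i is batch token i, so a per-token lookup works
inline int seq_of_token_simple(const sketch_ubatch & u, int i) {
    return u.seq_id_simple[i];
}

// equal split: tokens are grouped by sequence, so the flat index has to be
// decomposed into (s, j) first - indexing by i alone would mix sequences up
inline int seq_of_token_equal(const sketch_ubatch & u, int idx) {
    const int s = idx / u.n_seq_tokens; // which sequence slot the token sits in
    return u.seq_id_equal[s];
}
```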

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 20 additions & 25 deletions
@@ -100,17 +100,14 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
     // first try simple split
     do {
-        balloc.split_reset();
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
         std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
 
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_simple(n_ubatch);
 
-            ubatches.push_back(std::move(ubatch)); // NOLINT
+            ubatches.push_back(ubatch);
         }
 
         auto heads_base = kv_base->prepare(ubatches);
@@ -125,23 +122,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
     do {
-        balloc.split_reset();
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
 
         std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch);
 
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_equal(n_ubatch);
 
-            ubatches.push_back(std::move(ubatch)); // NOLINT
+            ubatches.push_back(ubatch);
         }
 
         auto heads_base = kv_base->prepare(ubatches);
@@ -156,14 +150,14 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     // but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
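
A note on the control flow above: each do { ... } while (false) block acts as a labeled break - try the simple split, bail out of the block if preparation fails, then try the equal split, and only after both attempts give up with LLAMA_MEMORY_STATUS_FAILED_PREPARE. A minimal standalone sketch of that pattern (try_strategy, prep_status and init_batch_sketch are hypothetical stand-ins, not llama.cpp API):

```cpp
#include <optional>

enum class prep_status { ok, failed };

struct prep_result { prep_status status; };

// hypothetical helper standing in for "split the batch and prepare the cache"
static std::optional<prep_result> try_strategy(bool succeeds) {
    if (!succeeds) {
        return std::nullopt;
    }
    return prep_result{prep_status::ok};
}

static prep_result init_batch_sketch(bool simple_ok, bool equal_ok) {
    do {
        auto res = try_strategy(simple_ok); // analogue of the split_simple attempt
        if (!res) {
            break; // fall through to the next strategy
        }
        return *res;
    } while (false);

    do {
        auto res = try_strategy(equal_ok);  // analogue of the split_equal attempt
        if (!res) {
            break;
        }
        return *res;
    } while (false);

    return prep_result{prep_status::failed}; // FAILED_PREPARE analogue
}

int main() {
    return init_batch_sketch(/*simple_ok =*/ false, /*equal_ok =*/ true).status == prep_status::ok ? 0 : 1;
}
```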
@@ -200,13 +194,14 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 // llama_kv_cache_unified_iswa_context
 //
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv) :
-    ctx_base(kv->get_base()->init_full()),
-    ctx_swa (kv->get_swa ()->init_full()),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
src/llama-kv-cache-unified.cpp

Lines changed: 47 additions & 35 deletions
@@ -343,29 +343,23 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
     GGML_UNUSED(embd_all);
 
     do {
-        balloc.split_reset();
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
         std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
-
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
-
-            ubatches.push_back(std::move(ubatch)); // NOLINT
+        while (sbatch.n_tokens > 0) {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
         }
 
         auto heads = prepare(ubatches);
         if (heads.empty()) {
             break;
         }
 
-        return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(heads), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_state>(
+                this, std::move(sbatch), std::move(heads), std::move(ubatches));
     } while (false);
 
-    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_context_ptr llama_kv_cache_unified::init_full() {
@@ -559,7 +553,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     }
 
     if (debug > 0) {
-        LLAMA_LOG_CONT("\n");
         LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
 
         if ((debug == 2 && n_swa > 0) || debug > 2) {
@@ -685,29 +678,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
+    }
+
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
     for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
+            const uint32_t idx = s*ubatch.n_seq_tokens + j;
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos pos = cells.pos_get(head_cur + i);
+            if (!cells.is_empty(head_cur + idx)) {
+                assert(cells.seq_count(head_cur + idx) == 1);
 
-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
+                const llama_pos pos = cells.pos_get(head_cur + idx);
 
-            cells.rm(head_cur + i);
-        }
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+                cells.rm(head_cur + idx);
+            }
+
+            cells.pos_set(head_cur + idx, ubatch.pos[idx]);
 
-        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
+            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
+                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
+            }
         }
     }
 
@@ -726,7 +729,6 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
-
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
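
The restructured apply_ubatch above walks an equal-split ubatch as an n_seqs x n_seq_tokens grid: each (sequence s, token j) pair maps to the flat ubatch index idx = s*n_seq_tokens + j and to the KV cell head_cur + idx. A standalone sketch of that indexing (toy values, not llama.cpp code):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_seqs       = 3;  // sequences in the ubatch
    const uint32_t n_seq_tokens = 4;  // tokens contributed by each sequence
    const uint32_t head_cur     = 10; // first KV cell assigned to this ubatch

    std::vector<uint32_t> cell_of_token(n_seqs * n_seq_tokens);

    for (uint32_t s = 0; s < n_seqs; ++s) {
        for (uint32_t j = 0; j < n_seq_tokens; ++j) {
            const uint32_t idx = s*n_seq_tokens + j; // flat position inside the ubatch
            cell_of_token[idx] = head_cur + idx;     // contiguous KV cells, as in apply_ubatch
        }
    }

    // token 1 of sequence 2 lands at flat index 2*4 + 1 = 9, i.e. KV cell 19
    printf("token (s=2, j=1) -> ubatch idx %u -> KV cell %u\n",
           2*n_seq_tokens + 1, cell_of_token[2*n_seq_tokens + 1]);
    return 0;
}
```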
@@ -823,11 +825,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const uint32_t n_tokens = ubatch->n_tokens;
+    const uint32_t n_tokens     = ubatch->n_tokens;
+    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
+    const uint32_t n_seqs       = ubatch->n_seqs;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
 
+    const int64_t n_kv = dst->ne[0];
     const int64_t n_kv = dst->ne[0];
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
@@ -843,10 +848,13 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //   xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
+                const uint32_t idx = s*n_seq_tokens + j;
 
-            const llama_pos p1 = ubatch->pos[i];
+                const llama_pos p1 = ubatch->pos[idx];
 
             for (uint32_t j = 0; j < n_kv; ++j) {
                 float f = 0.0f;
@@ -876,15 +884,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                     f = -INFINITY;
                 }
 
-                data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
+                data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
+            }
             }
         }
 
         // mask padded tokens
         if (data) {
-            for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                 }
             }
         }
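
The mask written by set_input_kq_mask above is, for each h, a row-major [n_tokens x n_kv] block: the entry for token row t and KV column c lives at h*(n_kv*n_tokens) + t*n_kv + c, and rows for padded tokens are filled with -INFINITY. A standalone sketch of that layout with a toy causal rule (it assumes KV cell i holds position i and omits padding; not llama.cpp code):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_kv     = 8; // KV cells visible to this ubatch
    const uint32_t n_tokens = 4; // tokens in the ubatch

    std::vector<float> data(n_kv * n_tokens, 0.0f);

    // toy causal rule: the token at row t may only attend to KV cells with position <= t
    for (uint32_t h = 0; h < 1; ++h) {
        for (uint32_t t = 0; t < n_tokens; ++t) {
            for (uint32_t c = 0; c < n_kv; ++c) {
                const float f = (c > t) ? -INFINITY : 0.0f;
                data[h*(n_kv*n_tokens) + t*n_kv + c] = f;
            }
        }
    }

    printf("mask(t=1, c=0) = %f  mask(t=1, c=5) = %f\n",
           data[1*n_kv + 0], data[1*n_kv + 5]);
    return 0;
}
```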
@@ -1534,9 +1543,12 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 
     seq_rm(dest_seq_id, -1, -1);
 
-    llama_batch_allocr balloc(hparams.n_pos_per_embd());
+    llama_sbatch sbatch;
+    llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
 
-    llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
+    ubatch.n_tokens = cell_count;
+    ubatch.n_seq_tokens = cell_count;
+    ubatch.n_seqs = 1;
 
     for (uint32_t i = 0; i < cell_count; ++i) {
         llama_pos pos;
