Skip to content

Commit 8f542f0

Browse files
ggerganov and qnixsynapse
authored and committed
kv-cache : fix split_equal handling in unified implementation (ggml-org#14130)
ggml-ci
1 parent 84251df commit 8f542f0

File tree

3 files changed

+133
-103
lines changed

3 files changed

+133
-103
lines changed

src/llama-context.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
862862
const auto & batch = balloc->get_batch();
863863

864864
// remember the sequence ids used during the encoding - needed for cross attention later
865+
// TODO: the sequence indexing here is likely not correct in the general case
866+
// probably works only for split_simple
865867
cross.seq_ids_enc.resize(n_tokens);
866868
for (uint32_t i = 0; i < n_tokens; i++) {
867869
cross.seq_ids_enc[i].clear();

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 63 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,19 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
9595
return kv_swa->seq_pos_max(seq_id);
9696
}
9797

98-
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
99-
GGML_UNUSED(embd_all);
98+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
99+
GGML_UNUSED(embd_pooled);
100100

101101
// first try simple split
102102
do {
103-
balloc.split_reset();
103+
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
104104

105105
std::vector<llama_ubatch> ubatches;
106-
while (true) {
107-
auto ubatch = balloc.split_simple(n_ubatch);
108106

109-
if (ubatch.n_tokens == 0) {
110-
break;
111-
}
107+
while (sbatch.n_tokens > 0) {
108+
auto ubatch = sbatch.split_simple(n_ubatch);
112109

113-
ubatches.push_back(std::move(ubatch)); // NOLINT
110+
ubatches.push_back(ubatch);
114111
}
115112

116113
auto heads_base = kv_base->prepare(ubatches);
@@ -125,23 +122,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
125122

126123
assert(heads_base.size() == heads_swa.size());
127124

128-
return std::make_unique<llama_kv_cache_unified_iswa_context>(
129-
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
125+
return std::make_unique<llama_kv_cache_unified_iswa_state>(
126+
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
130127
} while (false);
131128

132129
// if it fails, try equal split
133130
do {
134-
balloc.split_reset();
131+
auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
135132

136133
std::vector<llama_ubatch> ubatches;
137-
while (true) {
138-
auto ubatch = balloc.split_equal(n_ubatch);
139134

140-
if (ubatch.n_tokens == 0) {
141-
break;
142-
}
135+
while (sbatch.n_tokens > 0) {
136+
auto ubatch = sbatch.split_equal(n_ubatch);
143137

144-
ubatches.push_back(std::move(ubatch)); // NOLINT
138+
ubatches.push_back(ubatch);
145139
}
146140

147141
auto heads_base = kv_base->prepare(ubatches);
@@ -156,22 +150,22 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
156150

157151
assert(heads_base.size() == heads_swa.size());
158152

159-
return std::make_unique<llama_kv_cache_unified_iswa_context>(
160-
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
153+
return std::make_unique<llama_kv_cache_unified_iswa_state>(
154+
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
161155
} while (false);
162156

163157
// TODO: if we fail again, we should attempt different splitting strategies
164158
// but to do that properly, we first have to refactor the batches to be more flexible
165159

166-
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
160+
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
167161
}
168162

169-
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
170-
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
163+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
164+
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
171165
}
172166

173-
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
174-
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
167+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
168+
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
175169
}
176170

177171
bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -197,46 +191,52 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
197191
}
198192

199193
//
200-
// llama_kv_cache_unified_iswa_context
194+
// llama_kv_cache_unified_iswa_state
201195
//
202196

203-
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
197+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
204198

205-
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
206-
llama_kv_cache_unified_iswa * kv) :
207-
ctx_base(kv->get_base()->init_full()),
208-
ctx_swa (kv->get_swa ()->init_full()),
209-
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
199+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
200+
llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
201+
state_base = kv->get_base()->init_full();
202+
state_swa = kv->get_swa ()->init_full();
203+
204+
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
210205
}
211206

212-
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
207+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
213208
llama_kv_cache_unified_iswa * kv,
214209
llama_context * lctx,
215-
bool optimize) :
216-
ctx_base(kv->get_base()->init_update(lctx, optimize)),
217-
ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
218-
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
210+
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
211+
state_base = kv->get_base()->init_update(lctx, optimize);
212+
state_swa = kv->get_swa ()->init_update(lctx, optimize);
213+
214+
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
219215
}
220216

221-
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
217+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
222218
llama_kv_cache_unified_iswa * kv,
219+
llama_sbatch sbatch,
223220
std::vector<uint32_t> heads_base,
224221
std::vector<uint32_t> heads_swa,
225-
std::vector<llama_ubatch> ubatches) :
226-
ubatches(std::move(ubatches)),
222+
std::vector<llama_ubatch> ubatches)
223+
: status(LLAMA_MEMORY_STATUS_SUCCESS),
224+
sbatch(std::move(sbatch)),
225+
ubatches(std::move(ubatches)) {
227226
// note: here we copy the ubatches. not sure if this is ideal
228-
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
229-
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
230-
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
227+
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
228+
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
229+
230+
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
231231
}
232232

233-
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
233+
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
234234

235-
bool llama_kv_cache_unified_iswa_context::next() {
235+
bool llama_kv_cache_unified_iswa_state::next() {
236236
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
237237

238-
ctx_base->next();
239-
ctx_swa ->next();
238+
state_base->next();
239+
state_swa ->next();
240240

241241
if (++i_next >= ubatches.size()) {
242242
return false;
@@ -245,35 +245,41 @@ bool llama_kv_cache_unified_iswa_context::next() {
245245
return true;
246246
}
247247

248-
bool llama_kv_cache_unified_iswa_context::apply() {
249-
assert(!llama_memory_status_is_fail(status));
248+
bool llama_kv_cache_unified_iswa_state::apply() {
249+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
250250

251251
bool res = true;
252252

253-
res = res & ctx_base->apply();
254-
res = res & ctx_swa ->apply();
253+
res = res & state_base->apply();
254+
res = res & state_swa ->apply();
255255

256256
return res;
257257
}
258258

259-
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
259+
std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
260+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
261+
262+
return sbatch.out_ids;
263+
}
264+
265+
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
260266
return status;
261267
}
262268

263-
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
269+
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
264270
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
265271

266272
return ubatches[i_next];
267273
}
268274

269-
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
275+
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
270276
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
271277

272-
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
278+
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
273279
}
274280

275-
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
281+
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
276282
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
277283

278-
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
284+
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
279285
}

0 commit comments

Comments
 (0)