Skip to content

Commit c022829

Browse files
ggerganov and qnixsynapse
authored and committed
kv-cache : fix split_equal handling in unified implementation (ggml-org#14130)
ggml-ci
1 parent 64a745d commit c022829

File tree

3 files changed

+145
-125
lines changed

3 files changed

+145
-125
lines changed

src/llama-context.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
862862
const auto & batch = balloc->get_batch();
863863

864864
// remember the sequence ids used during the encoding - needed for cross attention later
865+
// TODO: the seuqence indexing here is likely not correct in the general case
866+
// probably works only for split_simple
865867
cross.seq_ids_enc.resize(n_tokens);
866868
for (uint32_t i = 0; i < n_tokens; i++) {
867869
cross.seq_ids_enc[i].clear();

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 75 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -95,93 +95,77 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
9595
return kv_swa->seq_pos_max(seq_id);
9696
}
9797

98-
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
99-
GGML_UNUSED(embd_all);
98+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
99+
GGML_UNUSED(embd_pooled);
100100

101101
// first try simple split
102102
do {
103-
balloc.split_reset();
103+
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
104104

105105
std::vector<llama_ubatch> ubatches;
106-
while (true) {
107-
auto ubatch = balloc.split_simple(n_ubatch);
108106

109-
if (ubatch.n_tokens == 0) {
110-
break;
111-
}
107+
while (sbatch.n_tokens > 0) {
108+
auto ubatch = sbatch.split_simple(n_ubatch);
112109

113-
ubatches.push_back(std::move(ubatch)); // NOLINT
110+
ubatches.push_back(ubatch);
114111
}
115112

116-
if (balloc.get_n_used() < balloc.get_n_tokens()) {
117-
// failed to find a suitable split
113+
auto heads_base = kv_base->prepare(ubatches);
114+
if (heads_base.empty()) {
118115
break;
119116
}
120117

121-
auto sinfos_base = kv_base->prepare(ubatches);
122-
if (sinfos_base.empty()) {
118+
auto heads_swa = kv_swa->prepare(ubatches);
119+
if (heads_swa.empty()) {
123120
break;
124121
}
125122

126-
auto sinfos_swa = kv_swa->prepare(ubatches);
127-
if (sinfos_swa.empty()) {
128-
break;
129-
}
123+
assert(heads_base.size() == heads_swa.size());
130124

131-
assert(sinfos_base.size() == sinfos_swa.size());
132-
133-
return std::make_unique<llama_kv_cache_unified_iswa_context>(
134-
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
125+
return std::make_unique<llama_kv_cache_unified_iswa_state>(
126+
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
135127
} while (false);
136128

137129
// if it fails, try equal split
138130
do {
139-
balloc.split_reset();
131+
auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
140132

141133
std::vector<llama_ubatch> ubatches;
142-
while (true) {
143-
auto ubatch = balloc.split_equal(n_ubatch, false);
144134

145-
if (ubatch.n_tokens == 0) {
146-
break;
147-
}
135+
while (sbatch.n_tokens > 0) {
136+
auto ubatch = sbatch.split_equal(n_ubatch);
148137

149-
ubatches.push_back(std::move(ubatch)); // NOLINT
138+
ubatches.push_back(ubatch);
150139
}
151140

152-
if (balloc.get_n_used() < balloc.get_n_tokens()) {
153-
// failed to find a suitable split
141+
auto heads_base = kv_base->prepare(ubatches);
142+
if (heads_base.empty()) {
154143
break;
155144
}
156145

157-
auto sinfos_base = kv_base->prepare(ubatches);
158-
if (sinfos_base.empty()) {
146+
auto heads_swa = kv_swa->prepare(ubatches);
147+
if (heads_swa.empty()) {
159148
break;
160149
}
161150

162-
auto sinfos_swa = kv_swa->prepare(ubatches);
163-
if (sinfos_swa.empty()) {
164-
break;
165-
}
166-
167-
assert(sinfos_base.size() == sinfos_swa.size());
151+
assert(heads_base.size() == heads_swa.size());
168152

169-
return std::make_unique<llama_kv_cache_unified_iswa_context>(
170-
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
153+
return std::make_unique<llama_kv_cache_unified_iswa_state>(
154+
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
171155
} while (false);
172156

173157
// TODO: if we fail again, we should attempt different splitting strategies
174158
// but to do that properly, we first have to refactor the batches to be more flexible
175159

176-
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
160+
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
177161
}
178162

179-
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
180-
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
163+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
164+
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
181165
}
182166

183-
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
184-
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
167+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
168+
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
185169
}
186170

187171
bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -207,46 +191,52 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
207191
}
208192

209193
//
210-
// llama_kv_cache_unified_iswa_context
194+
// llama_kv_cache_unified_iswa_state
211195
//
212196

213-
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
197+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
198+
199+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
200+
llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
201+
state_base = kv->get_base()->init_full();
202+
state_swa = kv->get_swa ()->init_full();
214203

215-
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
216-
llama_kv_cache_unified_iswa * kv) :
217-
ctx_base(kv->get_base()->init_full()),
218-
ctx_swa (kv->get_swa ()->init_full()),
219-
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
204+
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
220205
}
221206

222-
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
207+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
223208
llama_kv_cache_unified_iswa * kv,
224209
llama_context * lctx,
225-
bool optimize) :
226-
ctx_base(kv->get_base()->init_update(lctx, optimize)),
227-
ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
228-
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
210+
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
211+
state_base = kv->get_base()->init_update(lctx, optimize);
212+
state_swa = kv->get_swa ()->init_update(lctx, optimize);
213+
214+
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
229215
}
230216

231-
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
217+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
232218
llama_kv_cache_unified_iswa * kv,
233-
slot_info_vec_t sinfos_base,
234-
slot_info_vec_t sinfos_swa,
235-
std::vector<llama_ubatch> ubatches) :
236-
ubatches(std::move(ubatches)),
219+
llama_sbatch sbatch,
220+
std::vector<uint32_t> heads_base,
221+
std::vector<uint32_t> heads_swa,
222+
std::vector<llama_ubatch> ubatches)
223+
: status(LLAMA_MEMORY_STATUS_SUCCESS),
224+
sbatch(std::move(sbatch)),
225+
ubatches(std::move(ubatches)) {
237226
// note: here we copy the ubatches. not sure if this is ideal
238-
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
239-
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
240-
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
227+
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
228+
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
229+
230+
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
241231
}
242232

243-
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
233+
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
244234

245-
bool llama_kv_cache_unified_iswa_context::next() {
235+
bool llama_kv_cache_unified_iswa_state::next() {
246236
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
247237

248-
ctx_base->next();
249-
ctx_swa ->next();
238+
state_base->next();
239+
state_swa ->next();
250240

251241
if (++i_next >= ubatches.size()) {
252242
return false;
@@ -255,35 +245,41 @@ bool llama_kv_cache_unified_iswa_context::next() {
255245
return true;
256246
}
257247

258-
bool llama_kv_cache_unified_iswa_context::apply() {
259-
assert(!llama_memory_status_is_fail(status));
248+
bool llama_kv_cache_unified_iswa_state::apply() {
249+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
260250

261251
bool res = true;
262252

263-
res = res & ctx_base->apply();
264-
res = res & ctx_swa ->apply();
253+
res = res & state_base->apply();
254+
res = res & state_swa ->apply();
265255

266256
return res;
267257
}
268258

269-
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
259+
std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
260+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
261+
262+
return sbatch.out_ids;
263+
}
264+
265+
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
270266
return status;
271267
}
272268

273-
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
269+
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
274270
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
275271

276272
return ubatches[i_next];
277273
}
278274

279-
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
275+
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
280276
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
281277

282-
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
278+
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
283279
}
284280

285-
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
281+
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
286282
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
287283

288-
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
284+
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
289285
}

0 commit comments

Comments (0)