@@ -95,22 +95,19 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    GGML_UNUSED(embd_all);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+    GGML_UNUSED(embd_pooled);
 
     // first try simple split
     do {
-        balloc.split_reset();
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
         std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
+
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_simple(n_ubatch);
 
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
-
-            ubatches.push_back(std::move(ubatch)); // NOLINT
+            ubatches.push_back(ubatch);
         }
 
         auto heads_base = kv_base->prepare(ubatches);
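The two sides split the batch with different loop shapes: the `-` side runs an open-ended loop and stops on an empty-ubatch sentinel, while the `+` side loops while `sbatch.n_tokens` is still positive. A minimal, self-contained sketch of both idioms, using a hypothetical `splitter` stand-in rather than the real llama.cpp types:

```cpp
#include <cstdint>
#include <vector>

// hypothetical stand-in for a batch splitter; not the real llama.cpp API
struct splitter {
    uint32_t n_tokens; // tokens not yet split off

    // carve off up to n tokens; returns an empty chunk once exhausted
    std::vector<uint32_t> take(uint32_t n) {
        const uint32_t k = n < n_tokens ? n : n_tokens;
        n_tokens -= k;
        return std::vector<uint32_t>(k, 0);
    }
};

int main() {
    // counter style (the "+" side): loop while tokens remain
    splitter a{100};
    std::vector<std::vector<uint32_t>> ubatches_a;
    while (a.n_tokens > 0) {
        ubatches_a.push_back(a.take(32));
    }

    // sentinel style (the "-" side): loop forever, stop on an empty split
    splitter b{100};
    std::vector<std::vector<uint32_t>> ubatches_b;
    while (true) {
        auto u = b.take(32);
        if (u.empty()) {
            break;
        }
        ubatches_b.push_back(std::move(u));
    }

    return ubatches_a.size() == ubatches_b.size() ? 0 : 1;
}
```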
@@ -125,23 +122,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
     do {
-        balloc.split_reset();
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
 
         std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch);
+
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_equal(n_ubatch);
 
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
-
-            ubatches.push_back(std::move(ubatch)); // NOLINT
+            ubatches.push_back(ubatch);
         }
 
         auto heads_base = kv_base->prepare(ubatches);
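Both strategies sit inside `do { ... } while (false);` blocks, so a `break` abandons the current strategy and falls through to the next one without a goto. A self-contained sketch of that fallback pattern, with hypothetical `prepare_*` stand-ins for the real `kv_base->prepare()` / `kv_swa->prepare()` checks:

```cpp
#include <cstdio>

// hypothetical stand-ins for the real prepare() checks
static bool prepare_simple() { return false; } // pretend the simple split fails
static bool prepare_equal () { return true;  }

int main() {
    // first try the simple strategy
    do {
        if (!prepare_simple()) {
            break; // abandon this strategy and fall through
        }
        std::puts("simple split succeeded");
        return 0;
    } while (false);

    // if it fails, try the fallback strategy
    do {
        if (!prepare_equal()) {
            break;
        }
        std::puts("equal split succeeded");
        return 0;
    } while (false);

    // mirrors returning a FAILED_PREPARE state in the diff above
    std::puts("all strategies failed");
    return 1;
}
```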
@@ -156,22 +150,22 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     //       but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
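The factory methods above all follow the same shape: the ISWA cache owns a base (non-SWA) cache and an SWA cache and fans every operation out to both. A toy reduction of that wrapper pattern, with hypothetical types in place of `llama_kv_cache_unified` (not the actual llama.cpp implementation):

```cpp
#include <memory>

// hypothetical sub-cache; stands in for llama_kv_cache_unified
struct sub_cache {
    void clear() { /* drop all cells */ }
    bool can_shift() const { return true; }
};

// toy reduction of the ISWA wrapper: one object owning two sub-caches
struct iswa_cache {
    std::unique_ptr<sub_cache> base = std::make_unique<sub_cache>();
    std::unique_ptr<sub_cache> swa  = std::make_unique<sub_cache>();

    // mutations fan out to both children
    void clear() {
        base->clear();
        swa ->clear();
    }

    // queries combine both children's answers
    bool can_shift() const { return base->can_shift() && swa->can_shift(); }
};

int main() {
    iswa_cache kv;
    kv.clear();
    return kv.can_shift() ? 0 : 1;
}
```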
@@ -197,46 +191,52 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 }
 
 //
-// llama_kv_cache_unified_iswa_context
+// llama_kv_cache_unified_iswa_state
 //
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv) :
-    ctx_base(kv->get_base()->init_full()),
-    ctx_swa (kv->get_swa ()->init_full()),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
-        bool optimize) :
-    ctx_base(kv->get_base()->init_update(lctx, optimize)),
-    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
+        llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
-        std::vector<llama_ubatch> ubatches) :
-    ubatches(std::move(ubatches)),
+        std::vector<llama_ubatch> ubatches)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)) {
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
-llama_kv_cache_unified_iswa_context::~llama_kv_cache_unified_iswa_context() = default;
+llama_kv_cache_unified_iswa_state::~llama_kv_cache_unified_iswa_state() = default;
 
-bool llama_kv_cache_unified_iswa_context::next() {
+bool llama_kv_cache_unified_iswa_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    ctx_base->next();
-    ctx_swa ->next();
+    state_base->next();
+    state_swa ->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
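With the sub-states created in the constructor body rather than in the member-initializer list, `status` can no longer be computed at initialization time, so the `+` side seeds it with `LLAMA_MEMORY_STATUS_SUCCESS` and recomputes it once both children exist. The combine rule this relies on is, presumably, "success only if both children succeeded"; a self-contained sketch of that semantics (illustrative enum and function, not the actual `llama_memory_status_combine` implementation):

```cpp
#include <cassert>

// illustrative status type; mirrors the shape of the llama.cpp enum only
enum status_t { STATUS_SUCCESS, STATUS_FAILED_PREPARE };

// assumed combine rule: the pair succeeds only if both halves succeed
static status_t status_combine(status_t a, status_t b) {
    return (a == STATUS_SUCCESS && b == STATUS_SUCCESS) ? STATUS_SUCCESS
                                                        : STATUS_FAILED_PREPARE;
}

int main() {
    assert(status_combine(STATUS_SUCCESS,        STATUS_SUCCESS)        == STATUS_SUCCESS);
    assert(status_combine(STATUS_SUCCESS,        STATUS_FAILED_PREPARE) == STATUS_FAILED_PREPARE);
    assert(status_combine(STATUS_FAILED_PREPARE, STATUS_SUCCESS)        == STATUS_FAILED_PREPARE);
    return 0;
}
```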
@@ -245,35 +245,41 @@ bool llama_kv_cache_unified_iswa_context::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_iswa_context::apply() {
-    assert(!llama_memory_status_is_fail(status));
+bool llama_kv_cache_unified_iswa_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     bool res = true;
 
-    res = res & ctx_base->apply();
-    res = res & ctx_swa ->apply();
+    res = res & state_base->apply();
+    res = res & state_swa ->apply();
 
     return res;
 }
 
-llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
+std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
+const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
+const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
+    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
 }
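One detail worth calling out in `apply()`: the results are accumulated with bitwise `&` instead of logical `&&`, so the SWA state's `apply()` still runs even when the base state's call fails; `&&` would short-circuit and skip it. An isolated, runnable illustration with hypothetical `apply_*` stand-ins:

```cpp
#include <cstdio>

// hypothetical stand-ins for state_base->apply() / state_swa->apply()
static bool apply_base() { std::puts("base applied"); return false; } // simulate failure
static bool apply_swa () { std::puts("swa applied");  return true;  }

int main() {
    bool res = true;
    res = res & apply_base(); // bitwise &: no short-circuit,
    res = res & apply_swa (); // so both sides always run
    std::printf("combined: %d\n", res ? 1 : 0); // 0: the failure is still reported
    return res ? 0 : 1;
}
```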