@@ -95,93 +95,77 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    GGML_UNUSED(embd_all);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+    GGML_UNUSED(embd_pooled);
 
     // first try simple split
     do {
-        balloc.split_reset();
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
         std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
-
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_simple(n_ubatch);
 
-            ubatches.push_back(std::move(ubatch)); // NOLINT
+            ubatches.push_back(ubatch);
         }
 
-        if (balloc.get_n_used() < balloc.get_n_tokens()) {
-            // failed to find a suitable split
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
             break;
         }
 
-        auto sinfos_base = kv_base->prepare(ubatches);
-        if (sinfos_base.empty()) {
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
             break;
         }
 
-        auto sinfos_swa = kv_swa->prepare(ubatches);
-        if (sinfos_swa.empty()) {
-            break;
-        }
+        assert(heads_base.size() == heads_swa.size());
 
-        assert(sinfos_base.size() == sinfos_swa.size());
-
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
     do {
-        balloc.split_reset();
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
 
         std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch, false);
-
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_equal(n_ubatch);
 
-            ubatches.push_back(std::move(ubatch)); // NOLINT
+            ubatches.push_back(ubatch);
         }
 
-        if (balloc.get_n_used() < balloc.get_n_tokens()) {
-            // failed to find a suitable split
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
             break;
         }
 
-        auto sinfos_base = kv_base->prepare(ubatches);
-        if (sinfos_base.empty()) {
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
             break;
        }
 
-        auto sinfos_swa = kv_swa->prepare(ubatches);
-        if (sinfos_swa.empty()) {
-            break;
-        }
-
-        assert(sinfos_base.size() == sinfos_swa.size());
+        assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     // but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
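Note: a minimal caller sketch for the init_batch() flow above. The names kv, batch and n_ubatch are assumed to exist in the caller; the accessors used (get_status, apply, get_ubatch, next) are the ones defined later in this file. Treat this as illustrative usage, not code from this change.

// hypothetical usage sketch -- `kv`, `batch` and `n_ubatch` are assumed to be in scope
auto mstate = kv->init_batch(batch, n_ubatch, /*embd_pooled=*/false, /*logits_all=*/false);
if (mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
    // neither the simple split nor the equal split produced a usable set of ubatches
    return;
}

do {
    if (!mstate->apply()) {
        // reserving cells in the base and SWA caches failed for this ubatch
        return;
    }

    const llama_ubatch & ubatch = mstate->get_ubatch();
    // ... build and evaluate the compute graph for `ubatch` ...
} while (mstate->next()); // advance until all ubatches are consumed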
@@ -207,46 +191,52 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 }
 
 //
-// llama_kv_cache_unified_iswa_context
+// llama_kv_cache_unified_iswa_state
 //
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv) :
-    ctx_base(kv->get_base()->init_full()),
-    ctx_swa (kv->get_swa ()->init_full()),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
-        bool optimize) :
-    ctx_base(kv->get_base()->init_update(lctx, optimize)),
-    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
-        slot_info_vec_t sinfos_base,
-        slot_info_vec_t sinfos_swa,
-        std::vector<llama_ubatch> ubatches) :
-    ubatches(std::move(ubatches)),
+        llama_sbatch sbatch,
+        std::vector<uint32_t> heads_base,
+        std::vector<uint32_t> heads_swa,
+        std::vector<llama_ubatch> ubatches)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)) {
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
-llama_kv_cache_unified_iswa_context::~llama_kv_cache_unified_iswa_context() = default;
+llama_kv_cache_unified_iswa_state::~llama_kv_cache_unified_iswa_state() = default;
 
-bool llama_kv_cache_unified_iswa_context::next() {
+bool llama_kv_cache_unified_iswa_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    ctx_base->next();
-    ctx_swa ->next();
+    state_base->next();
+    state_swa ->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -255,35 +245,41 @@ bool llama_kv_cache_unified_iswa_context::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_iswa_context::apply() {
-    assert(!llama_memory_status_is_fail(status));
+bool llama_kv_cache_unified_iswa_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     bool res = true;
 
-    res = res & ctx_base->apply();
-    res = res & ctx_swa ->apply();
+    res = res & state_base->apply();
+    res = res & state_swa ->apply();
 
     return res;
 }
 
-llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
+std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
+const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
+const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
+    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
 }
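Note: the iSWA state reports success only when both the base and the SWA sub-states succeed. A rough sketch of the combining rule that llama_memory_status_combine is assumed to apply follows; the real helper is defined elsewhere in the tree, so this is an approximation for reference, not its implementation.

// assumed behavior of llama_memory_status_combine, sketched for reference only
static llama_memory_status combine_status_sketch(llama_memory_status s_base, llama_memory_status s_swa) {
    if (s_base == LLAMA_MEMORY_STATUS_SUCCESS && s_swa == LLAMA_MEMORY_STATUS_SUCCESS) {
        return LLAMA_MEMORY_STATUS_SUCCESS;
    }

    // otherwise propagate a failure status so callers treat the whole iSWA state as failed
    return s_base != LLAMA_MEMORY_STATUS_SUCCESS ? s_base : s_swa;
}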