@@ -130,42 +130,20 @@ bool llama_batch_allocr::init(
                 warn = true;
             }
         }
-
-        if (warn) {
-            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
-
-            output.resize(batch.n_tokens, true);
-            batch.logits = output.data();
-        }
-    }
-
-    //
-    // compute stats
-    //
-
-    this->n_embd = n_embd;
-
-    // count the outputs in this batch
-    for (int32_t i = 0; i < batch.n_tokens; ++i) {
-        n_outputs += batch.logits[i] != 0;
     }
-
-    // determine coupled sequences
-    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
-    for (int32_t i = 0; i < batch.n_tokens; ++i) {
-        const llama_seq_id s0 = batch.seq_id[i][0];
-
-        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-            const llama_seq_id s1 = batch.seq_id[i][s];
-
-            seq_pos[s1].insert(batch.pos[i]);
-
-            if (s > 0) {
-                // mark that sequence s1 is coupled to s0
-                seq_cpl[s1][s0] = true;
-
-                // note: tracking the other way around is not necessary for now
-                //seq_cpl[s0][s1] = true;
+    if (batch->logits) {
+        if (ubatch.equal_seqs) {
+            for (size_t i = 0; i < length; ++i) {
+                size_t id = ids[seq.offset + i];
+                int8_t is_output = batch->logits[id];
+                ubatch.output[ubatch.n_tokens + i] = is_output;
+                if (is_output) { out_ids.push_back(id); }
+            }
+        } else {
+            // simple split
+            ubatch.output = batch->logits + seq.offset;
+            for (size_t i = 0; i < length; ++i) {
+                if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
             }
         }
     }
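Note: the block added in this hunk decides which tokens of the ubatch should produce outputs. When sequences are reordered (ubatch.equal_seqs), the per-token flags in batch->logits are copied one by one through the ids permutation; for a simple split the ubatch can read the flags in place at the batch offset. The snippet below is a minimal, self-contained sketch of that distinction only; the vectors and the equal_seqs flag are simplified stand-ins, not the actual llama.cpp structures.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // stand-in for batch->logits: per-token output flags (hypothetical data)
        std::vector<int8_t>  logits = {0, 0, 1, 0, 1, 1};
        // stand-in for ids: permutation applied when sequences are reordered
        std::vector<size_t>  ids    = {4, 5, 2, 0, 1, 3};
        std::vector<int32_t> out_ids;

        const size_t offset = 0, length = logits.size();
        const bool   equal_seqs = true;

        std::vector<int8_t> output(length);
        if (equal_seqs) {
            // reordered split: copy the flags through the permutation
            for (size_t i = 0; i < length; ++i) {
                const size_t id = ids[offset + i];
                output[i] = logits[id];
                if (output[i]) { out_ids.push_back((int32_t) id); }
            }
        } else {
            // simple split: read the flags in place, offset into the batch
            for (size_t i = 0; i < length; ++i) {
                output[i] = logits[offset + i];
                if (output[i]) { out_ids.push_back((int32_t) (offset + i)); }
            }
        }

        for (int32_t id : out_ids) { printf("output token id: %d\n", id); }
        return 0;
    }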
@@ -281,141 +259,49 @@ bool llama_batch_allocr::init(
         }
     }
 
-    if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
-                if (seq_cpl[s0][s1]) {
-                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
-                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
-                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
-                        return false;
-                    }
-                }
+llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
+    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+    if (!seq.empty()) {
+        size_t length = 0;
+        size_t n_tokens_in_ubatch = 0;
+        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
+        // smallest first, because it's easier to split this way;
+        // starting from the end to pop in constant time.
+        for (size_t i = seq.size(); i-- > 0;) {
+            llama_sbatch_seq & s = seq[i];
+            GGML_ASSERT(s.length > 0);
+            if (length == 0) {
+                length = s.length < n_ubatch ? s.length : n_ubatch;
             }
+            add_seq_to_ubatch(ubatch, s, length);
+            n_tokens_in_ubatch += length;
+            // shared prompts can't be mixed with any of their sequences,
+            // so it's safer to compute them in their own ubatch
+            if (s.n_seq_id > 1) { break; }
+            // stop when there isn't enough space for another sequence
+            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
         }
     }
-
-    // disallow partial sequence sub-sets:
-    //
-    // invalid:          x
-    //            i: 0 1 2 ...
-    // ---------------------------------------
-    // seq_id[i][0]: 0 0 1
-    // seq_id[i][1]: 1 1 2
-    // seq_id[i][2]: 2
-    //
-    // disallow decreasing sequence positions:
-    //
-    // invalid:                  x
-    //            i: 0 1 2 3 4 5 6 ...
-    // ---------------------------------------
-    //       pos[i]: 4 5 0 1 6 2 3
-    // seq_id[i][0]: 0 0 1 1 0 1 0
-    //
-    {
-        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-            cur_seq_set[s].set();
-        }
-
-        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-            cur_seq_pos[s] = -1;
-        }
-
-        for (int32_t i = 0; i < batch.n_tokens; ++i) {
-            const llama_pos pos = batch.pos[i];
-
-            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = batch.seq_id[i][s];
-
-                cur_seq_set[seq_id] &= seq_set[i];
-
-                if (cur_seq_set[seq_id].none()) {
-                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
-                    return false;
-                }
-
-                if (pos < cur_seq_pos[seq_id]) {
-                    LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
-                    return false;
-                }
-            }
-        }
-    }
-
-    split_reset();
-
-    return true;
+    return ubatch;
 }
 
-llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
-    const uint32_t n_tokens = n_seq_tokens*n_seqs;
-
-    clear();
-    split_reset();
-
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
-
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .clear();
-    ubatch.pos       .resize(n_tokens);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
-
-    for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
+    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+    if (!seq.empty()) {
+        llama_sbatch_seq & s = seq[seq.size() - 1];
+        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
+        add_seq_to_ubatch(ubatch, s, length);
     }
-
-    llama_ubatch res {
-        /*.equal_seqs   =*/ true,
-        /*.n_tokens     =*/ n_tokens,
-        /*.n_seq_tokens =*/ n_seq_tokens,
-        /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ n_seqs,
-
-        /*.token        =*/ ubatch.token.data(),
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
-    };
-
-    return res;
+    return ubatch;
 }
 
-const llama_batch & llama_batch_allocr::get_batch() const {
-    return batch;
-}
-
-uint32_t llama_batch_allocr::get_n_tokens() const {
-    return batch.n_tokens;
-}
-
-uint32_t llama_batch_allocr::get_n_outputs() const {
-    return n_outputs;
-}
-
-std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
-    return out_ids;
-}
-
-llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
-    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
-}
-
-llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
-    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
-}
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
+    GGML_ASSERT(batch.n_tokens >= 0);
+    this->batch = &batch;
+    this->n_embd = n_embd;
 
 void llama_batch_allocr::split_reset() {
     out_ids.clear();
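For orientation, the packing policy that the re-added split_equal implements can be summarized as: walk the sequences smallest-first (from the back of the list); take the same slice length from each, fixed by the first sequence taken; stop early once a shared prompt (n_seq_id > 1) has been added, since it cannot be mixed with its own sequences; and stop once another slice would no longer fit in n_ubatch. The stand-alone sketch below reproduces just that loop with simplified stand-in types and hypothetical sequence data; it is not the llama.cpp implementation.

    #include <cstdio>
    #include <vector>

    struct seq_info {
        int    n_seq_id; // > 1 marks a shared prompt
        size_t length;   // tokens remaining in this sequence
    };

    int main() {
        // hypothetical sequences, kept sorted largest-first as in the sbatch
        std::vector<seq_info> seqs = { {1, 7}, {1, 5}, {2, 4}, {1, 3} };
        const size_t n_ubatch = 8;

        size_t length = 0;             // slice taken from each sequence
        size_t n_tokens_in_ubatch = 0;

        // visit the smallest sequences first by walking from the end
        for (size_t i = seqs.size(); i-- > 0;) {
            const seq_info & s = seqs[i];
            if (length == 0) {
                length = s.length < n_ubatch ? s.length : n_ubatch;
            }
            n_tokens_in_ubatch += length;
            printf("take %zu tokens from sequence %zu\n", length, i);
            // a shared prompt ends the ubatch after it is added
            if (s.n_seq_id > 1) { break; }
            // stop when there is no room for another slice
            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
        }
        printf("ubatch holds %zu of %zu tokens\n", n_tokens_in_ubatch, n_ubatch);
        return 0;
    }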