@@ -210,7 +210,7 @@ bool llama_batch_allocr::init(
210
210
LLAMA_LOG_DEBUG (" %s: input batch info:\n " , __func__);
211
211
212
212
llama_ubatch ubatch {
213
- /* .equal_seqs =*/ false ,
213
+ /* .b_equal_seqs =*/ false ,
214
214
/* .n_tokens =*/ (uint32_t ) batch.n_tokens ,
215
215
/* .n_seq_tokens =*/ (uint32_t ) 1 ,
216
216
/* .n_seqs =*/ (uint32_t ) batch.n_tokens ,
@@ -223,6 +223,7 @@ bool llama_batch_allocr::init(
223
223
/* .seq_id_unq =*/ this ->seq_id_unq .data (),
224
224
/* .seq_idx =*/ this ->seq_idx .data (),
225
225
/* .output =*/ batch.logits ,
226
+ /* .data =*/ {},
226
227
};
227
228
228
229
ubatch_print (ubatch, debug);
@@ -366,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
366
367
clear ();
367
368
split_reset ();
368
369
369
- ubatches. emplace_back ();
370
+ auto udata = std::make_shared<llama_ubatch:: data_t > ();
370
371
371
- auto & ubatch = ubatches.back ();
372
-
373
- ubatch.token .resize (n_tokens);
374
- ubatch.embd .clear ();
375
- ubatch.pos .resize (n_tokens);
376
- ubatch.n_seq_id .resize (n_tokens);
377
- ubatch.seq_id .resize (n_tokens);
378
- ubatch.seq_id_unq .resize (0 );
379
- ubatch.seq_idx .resize (LLAMA_MAX_SEQ, -1 );
380
- ubatch.output .resize (n_tokens);
372
+ udata->token .resize (n_tokens);
373
+ udata->embd .clear ();
374
+ udata->pos .resize (n_tokens);
375
+ udata->n_seq_id .resize (n_tokens);
376
+ udata->seq_id .resize (n_tokens);
377
+ udata->seq_id_unq .resize (0 );
378
+ udata->seq_idx .resize (LLAMA_MAX_SEQ, -1 );
379
+ udata->output .resize (n_tokens);
381
380
382
381
for (uint32_t s = 0 ; s < n_seqs; ++s) {
383
- ubatch. seq_idx [s] = s;
384
- ubatch. seq_id_unq .push_back (s);
382
+ udata-> seq_idx [s] = s;
383
+ udata-> seq_id_unq .push_back (s);
385
384
}
386
385
387
386
llama_ubatch res {
388
- /* .equal_seqs =*/ true ,
387
+ /* .b_equal_seqs =*/ true ,
389
388
/* .n_tokens =*/ n_tokens,
390
389
/* .n_seq_tokens =*/ n_seq_tokens,
391
390
/* .n_seqs =*/ n_seqs,
392
391
/* .n_seqs_unq =*/ n_seqs,
393
392
394
- /* .token =*/ ubatch. token .data (),
393
+ /* .token =*/ udata-> token .data (),
395
394
/* .embd =*/ nullptr ,
396
- /* .pos =*/ ubatch.pos .data (),
397
- /* .n_seq_id =*/ ubatch.n_seq_id .data (),
398
- /* .seq_id =*/ ubatch.seq_id .data (),
399
- /* .seq_id_unq =*/ ubatch.seq_id_unq .data (),
400
- /* .seq_idx =*/ ubatch.seq_idx .data (),
401
- /* .output =*/ ubatch.output .data (),
395
+ /* .pos =*/ udata->pos .data (),
396
+ /* .n_seq_id =*/ udata->n_seq_id .data (),
397
+ /* .seq_id =*/ udata->seq_id .data (),
398
+ /* .seq_id_unq =*/ udata->seq_id_unq .data (),
399
+ /* .seq_idx =*/ udata->seq_idx .data (),
400
+ /* .output =*/ udata->output .data (),
401
+ /* .data =*/ std::move (udata),
402
402
};
403
403
404
404
return res;
@@ -439,8 +439,6 @@ void llama_batch_allocr::split_reset() {
439
439
440
440
used.clear ();
441
441
used.resize (get_n_tokens (), false );
442
-
443
- ubatches.clear ();
444
442
}
445
443
446
444
llama_ubatch llama_batch_allocr::split_simple (uint32_t n_ubatch) {
@@ -655,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
655
653
656
654
assert (n_tokens%n_seqs == 0 );
657
655
658
- ubatches.emplace_back ();
659
-
660
- auto & ubatch = ubatches.back ();
656
+ auto udata = std::make_shared<llama_ubatch::data_t >();
661
657
662
658
const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1 ;
663
659
664
660
const int64_t n_embd_all = batch.embd ? (int64_t ) n_tokens*n_embd : 0 ;
665
661
const int64_t n_pos_all = (int64_t ) n_tokens*n_pos_cur;
666
662
667
- ubatch. token .resize (n_tokens);
668
- ubatch. embd .resize (n_embd_all);
669
- ubatch. pos .resize (n_pos_all);
670
- ubatch. n_seq_id .resize (n_tokens);
671
- ubatch. seq_id .resize (n_tokens);
672
- ubatch. seq_id_unq .resize (0 );
673
- ubatch. seq_idx .resize (LLAMA_MAX_SEQ, -1 );
674
- ubatch. output .resize (n_tokens);
663
+ udata-> token .resize (n_tokens);
664
+ udata-> embd .resize (n_embd_all);
665
+ udata-> pos .resize (n_pos_all);
666
+ udata-> n_seq_id .resize (n_tokens);
667
+ udata-> seq_id .resize (n_tokens);
668
+ udata-> seq_id_unq .resize (0 );
669
+ udata-> seq_idx .resize (LLAMA_MAX_SEQ, -1 );
670
+ udata-> output .resize (n_tokens);
675
671
676
672
seq_set_t seq_set_unq;
677
673
678
674
for (size_t i = 0 ; i < idxs.size (); ++i) {
679
675
if (batch.token ) {
680
- ubatch. token [i] = batch.token [idxs[i]];
676
+ udata-> token [i] = batch.token [idxs[i]];
681
677
}
682
678
683
679
if (batch.embd ) {
684
- memcpy (ubatch. embd .data () + i*n_embd, batch.embd + (int64_t ) idxs[i]*n_embd, n_embd*sizeof (float ));
680
+ memcpy (udata-> embd .data () + i*n_embd, batch.embd + (int64_t ) idxs[i]*n_embd, n_embd*sizeof (float ));
685
681
}
686
682
687
683
for (int j = 0 ; j < n_pos_cur; ++j) {
688
- ubatch. pos [j*n_tokens + i] = batch.pos [j*batch.n_tokens + idxs[i]];
684
+ udata-> pos [j*n_tokens + i] = batch.pos [j*batch.n_tokens + idxs[i]];
689
685
}
690
686
691
- ubatch. n_seq_id [i] = batch.n_seq_id [idxs[i]];
692
- ubatch. seq_id [i] = batch.seq_id [idxs[i]];
693
- ubatch. output [i] = batch.logits [idxs[i]];
687
+ udata-> n_seq_id [i] = batch.n_seq_id [idxs[i]];
688
+ udata-> seq_id [i] = batch.seq_id [idxs[i]];
689
+ udata-> output [i] = batch.logits [idxs[i]];
694
690
695
- for (int s = 0 ; s < ubatch. n_seq_id [i]; ++s) {
696
- seq_set_unq.set (ubatch. seq_id [i][s]);
691
+ for (int s = 0 ; s < udata-> n_seq_id [i]; ++s) {
692
+ seq_set_unq.set (udata-> seq_id [i][s]);
697
693
}
698
694
699
- if (ubatch. output [i]) {
695
+ if (udata-> output [i]) {
700
696
out_ids.push_back (idxs[i]);
701
697
}
702
698
}
703
699
704
700
for (uint32_t s = 0 ; s < n_seq_max; ++s) {
705
701
if (seq_set_unq.test (s)) {
706
- ubatch. seq_idx [s] = ubatch. seq_id_unq .size ();
707
- ubatch. seq_id_unq .push_back (s);
702
+ udata-> seq_idx [s] = udata-> seq_id_unq .size ();
703
+ udata-> seq_id_unq .push_back (s);
708
704
}
709
705
}
710
706
711
707
llama_ubatch res {
712
- /* .equal_seqs =*/ equal_seqs,
708
+ /* .b_equal_seqs =*/ equal_seqs,
713
709
/* .n_tokens =*/ n_tokens,
714
710
/* .n_seq_tokens =*/ n_tokens/n_seqs,
715
711
/* .n_seqs =*/ n_seqs,
716
- /* .n_seqs_unq =*/ (uint32_t ) ubatch.seq_id_unq .size (),
717
-
718
- /* .token =*/ batch.token ? ubatch.token .data () : nullptr ,
719
- /* .embd =*/ batch.embd ? ubatch.embd .data () : nullptr ,
720
- /* .pos =*/ ubatch.pos .data (),
721
- /* .n_seq_id =*/ ubatch.n_seq_id .data (),
722
- /* .seq_id =*/ ubatch.seq_id .data (),
723
- /* .seq_id_unq =*/ ubatch.seq_id_unq .data (),
724
- /* .seq_idx =*/ ubatch.seq_idx .data (),
725
- /* .output =*/ ubatch.output .data (),
712
+ /* .n_seqs_unq =*/ (uint32_t ) udata->seq_id_unq .size (),
713
+
714
+ /* .token =*/ batch.token ? udata->token .data () : nullptr ,
715
+ /* .embd =*/ batch.embd ? udata->embd .data () : nullptr ,
716
+ /* .pos =*/ udata->pos .data (),
717
+ /* .n_seq_id =*/ udata->n_seq_id .data (),
718
+ /* .seq_id =*/ udata->seq_id .data (),
719
+ /* .seq_id_unq =*/ udata->seq_id_unq .data (),
720
+ /* .seq_idx =*/ udata->seq_idx .data (),
721
+ /* .output =*/ udata->output .data (),
722
+ /* .data =*/ std::move (udata),
726
723
};
727
724
728
725
if (debug > 0 ) {
729
- LLAMA_LOG_DEBUG (" %s: added ubatch %d to split:\n " , __func__, ( int ) ubatches. size () - 1 );
726
+ LLAMA_LOG_DEBUG (" %s: added ubatch to split:\n " , __func__);
730
727
731
728
ubatch_print (res, debug);
732
729
}
@@ -736,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
736
733
737
734
void llama_batch_allocr::ubatch_print (const llama_ubatch & ubatch, int debug) {
738
735
if (debug > 0 ) {
739
- LLAMA_LOG_DEBUG (" %s: equal_seqs = %d\n " , __func__, ubatch.equal_seqs );
736
+ LLAMA_LOG_DEBUG (" %s: equal_seqs = %d\n " , __func__, ubatch.equal_seqs () );
740
737
LLAMA_LOG_DEBUG (" %s: n_tokens = %d\n " , __func__, ubatch.n_tokens );
741
738
LLAMA_LOG_DEBUG (" %s: n_seq_tokens = %d\n " , __func__, ubatch.n_seq_tokens );
742
739
LLAMA_LOG_DEBUG (" %s: n_seqs = %d\n " , __func__, ubatch.n_seqs );
0 commit comments