@@ -1,7 +1,6 @@
 #include "llama-batch.h"
 
 #include "llama-impl.h"
-#include "llama-cparams.h"
 #include "llama-vocab.h"
 #include "llama-memory.h"
 
@@ -294,6 +293,8 @@ llama_batch_allocr::llama_batch_allocr() {
     for (auto & cur : seq_cpl) {
         cur.resize(LLAMA_MAX_SEQ);
     }
+
+    seq_idx.resize(LLAMA_MAX_SEQ);
 }
 
 bool llama_batch_allocr::init(
@@ -303,6 +304,8 @@ bool llama_batch_allocr::init(
         bool embd_all) {
     clear();
 
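+    // start from a clean splitting state - a new input batch invalidates any previous splits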
+    split_reset();
+
     batch = batch_inp;
 
     GGML_ASSERT(batch.n_tokens > 0);
@@ -433,6 +436,21 @@ bool llama_batch_allocr::init(
         }
     }
 
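+    // build the per-token sequence sets and the reverse lookups:
+    //   seq_idx[s]      - indices of the tokens that belong to sequence s
+    //   seq_set_map[ss] - indices of the tokens whose sequence set is exactly ss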
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        seq_set_t cur;
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            cur.set(batch.seq_id[i][s]);
+        }
+
+        seq_set.push_back(cur);
+
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            seq_idx[batch.seq_id[i][s]].push_back(i);
+        }
+
+        seq_set_map[cur].push_back(i);
+    }
+
     if (debug > 0) {
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
         LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
@@ -532,17 +550,47 @@ bool llama_batch_allocr::init(
532
550
}
533
551
}
534
552
553
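+    // consistency check: make sure each sequence is only used in compatible
+    // (non-conflicting) sequence sets across the batch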
+    {
+        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            cur_seq_set[s].set();
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
+                cur_seq_set[seq_id] &= seq_set[i];
+
+                if (cur_seq_set[seq_id].none()) {
+                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets\n", __func__, seq_id);
+                    return false;
+                }
+            }
+        }
+    }
+
+    // TODO: check that positions are increasing
+
     return true;
 }
 
 const llama_batch & llama_batch_allocr::get_batch() const {
     return batch;
 }
 
+uint32_t llama_batch_allocr::get_n_tokens() const {
+    return pos.size();
+}
+
 uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
+    return out_ids;
+}
+
 llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
     return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
 }
@@ -551,6 +599,215 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
     return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
 }
 
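+// reset the splitting state: the consumed-token flags, the collected output ids and any
+// previously produced ubatches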
+void llama_batch_allocr::split_reset() {
+    out_ids.clear();
+
+    used.clear();
+    used.resize(get_n_tokens(), false);
+
+    ubatches.clear();
+}
+
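+// simple split: consume the tokens in their original batch order, up to n_ubatch tokens
+// per call, starting at the first token that has not been used yet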
+llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
+    llama_ubatch res {
+        /*.equal_seqs   =*/ false,
+        /*.n_tokens     =*/ 0,
+        /*.n_seq_tokens =*/ 1,
+        /*.n_seqs       =*/ 0,
+
+        /*.token        =*/ nullptr,
+        /*.embd         =*/ nullptr,
+        /*.pos          =*/ nullptr,
+        /*.n_seq_id     =*/ nullptr,
+        /*.seq_id       =*/ nullptr,
+        /*.output       =*/ nullptr
+    };
+
+    // find the first unused token
+    uint32_t cur_idx = 0;
+    while (cur_idx < used.size() && used[cur_idx]) {
+        ++cur_idx;
+    }
+
+    // all tokens are used - return an empty ubatch
+    if (cur_idx >= used.size()) {
+        return res;
+    }
+
+    std::vector<int32_t> idxs;
+
+    while (true) {
+        res.n_tokens++;
+        res.n_seqs++;
+
+        idxs.push_back(cur_idx);
+
+        if (output[cur_idx] != 0) {
+            out_ids.push_back(cur_idx);
+        }
+
+        used[cur_idx] = true;
+
+        ++cur_idx;
+
+        if (cur_idx >= used.size()) {
+            break;
+        }
+
+        if (res.n_tokens >= n_ubatch) {
+            break;
+        }
+    }
+
+    add_ubatch(res, idxs);
+
+    return res;
+}
+
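+// equal split: pick up to n_ubatch non-overlapping sequence sets and take one token from
+// each set per step, so that every selected sequence set contributes the same number of tokens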
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+    llama_ubatch res {
+        /*.equal_seqs   =*/ true,
+        /*.n_tokens     =*/ 0,
+        /*.n_seq_tokens =*/ 0,
+        /*.n_seqs       =*/ 0,
+
+        /*.token        =*/ nullptr,
+        /*.embd         =*/ nullptr,
+        /*.pos          =*/ nullptr,
+        /*.n_seq_id     =*/ nullptr,
+        /*.seq_id       =*/ nullptr,
+        /*.output       =*/ nullptr
+    };
+
+    std::vector<seq_set_t> cur_seq_set;
+
+    // determine the sequence sets participating in this ubatch
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        if (used[i]) {
+            continue;
+        }
+
+        bool add = true;
+
+        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+            // skip tokens whose sequence set overlaps an already selected set
+            if (!(cur_seq_set[s] & seq_set[i]).none()) {
+                add = false;
+                break;
+            }
+        }
+
+        if (add) {
+            cur_seq_set.push_back(seq_set[i]);
+
+            if (cur_seq_set.size() > (size_t) n_ubatch) {
+                break;
+            }
+        }
+    }
+
+    res.n_seqs = cur_seq_set.size();
+
+    // the current token index per sequence set
+    std::vector<int32_t> cur_idx(cur_seq_set.size(), 0);
+
+    // skip tokens already consumed by previous ubatches
+    for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+        while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
+            ++cur_idx[s];
+        }
+    }
+
+    std::vector<int32_t> idxs;
+
+    // TODO: reorder from 012301230123..., to 000...111...222...333...
+    while (true) {
+        bool can_expand = true;
+
+        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+            if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
+                can_expand = false;
+                break;
+            }
+        }
+
+        if (!can_expand) {
+            break;
+        }
+
+        res.n_tokens += res.n_seqs;
+
+        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+            const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
+            idxs.push_back(idx);
+
+            if (output[idx] != 0) {
+                out_ids.push_back(idx);
+            }
+
+            used[idx] = true;
+
+            ++cur_idx[s];
+        }
+
+        if (res.n_tokens + res.n_seqs > n_ubatch) {
+            break;
+        }
+    }
+
+    add_ubatch(res, idxs);
+
+    return res;
+}
+
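+// single-sequence-set split: consume up to n_ubatch tokens from one sequence set, allowing
+// the set to narrow to a subset of itself as the batch is scanned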
+llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
+    llama_ubatch res {
+        /*.equal_seqs   =*/ true,
+        /*.n_tokens     =*/ 0,
+        /*.n_seq_tokens =*/ 0,
+        /*.n_seqs       =*/ 1,
+
+        /*.token        =*/ nullptr,
+        /*.embd         =*/ nullptr,
+        /*.pos          =*/ nullptr,
+        /*.n_seq_id     =*/ nullptr,
+        /*.seq_id       =*/ nullptr,
+        /*.output       =*/ nullptr,
+    };
+
+    // find the first unused token
+    uint32_t cur_idx = 0;
+    while (cur_idx < used.size() && used[cur_idx]) {
+        ++cur_idx;
+    }
+
+    // all tokens are used - return an empty ubatch
+    if (cur_idx >= used.size()) {
+        return res;
+    }
+
+    auto cur_seq_set = seq_set[cur_idx];
+
+    std::vector<int32_t> idxs;
+
+    while (true) {
+        res.n_tokens++;
+
+        idxs.push_back(cur_idx);
+
+        if (output[cur_idx] != 0) {
+            out_ids.push_back(cur_idx);
+        }
+
+        used[cur_idx] = true;
+
+        if (res.n_tokens >= n_ubatch) {
+            break;
+        }
+
+        // advance to the next unused token whose sequence set is a subset of the current one
+        do {
+            ++cur_idx;
+        } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
+
+        if (cur_idx == get_n_tokens()) {
+            break;
+        }
+
+        cur_seq_set = seq_set[cur_idx];
+    }
+
+    add_ubatch(res, idxs);
+
+    return res;
+}
+
 void llama_batch_allocr::clear() {
     n_outputs = 0;
 
@@ -567,6 +824,47 @@ void llama_batch_allocr::clear() {
     for (auto & cur : seq_cpl) {
         std::fill(cur.begin(), cur.end(), false);
     }
+
+    seq_set.clear();
+
+    for (auto & cur : seq_idx) {
+        cur.clear();
+    }
+
+    seq_set_map.clear();
+}
+
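+// gather the tokens selected by idxs into a new ubatch owned by the allocator and point
+// the result's pointers at it (they remain valid until the next split_reset())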
+void llama_batch_allocr::add_ubatch(llama_ubatch & res, const std::vector<int32_t> & idxs) {
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();
+
+    assert(res.n_tokens == idxs.size());
+
+    const auto n_tokens = res.n_tokens;
+
+    ubatch.token.resize(n_tokens);
+    //ubatch.embd.resize(0); // TODO
+    ubatch.pos.resize(n_tokens);
+    ubatch.n_seq_id.resize(n_tokens);
+    ubatch.seq_id.resize(n_tokens);
+    ubatch.output.resize(n_tokens);
+
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        ubatch.token[i]    = batch.token[idxs[i]];
+        //ubatch.embd[i]   = batch.embd[idxs[i]]; // TODO
+        ubatch.pos[i]      = batch.pos[idxs[i]];
+        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
+        ubatch.output[i]   = batch.logits[idxs[i]];
+    }
+
+    res.token    = ubatch.token.data();
+    //res.embd   = ubatch.embd.data(); // TODO
+    res.pos      = ubatch.pos.data();
+    res.n_seq_id = ubatch.n_seq_id.data();
+    res.seq_id   = ubatch.seq_id.data();
+    res.output   = ubatch.output.data();
 }
 
 //