Commit c4faef3 (parent: e434e69)

ubatch : new splitting logic (wip)

ggml-ci

11 files changed: +421 -67 lines changed

src/llama-batch.cpp

Lines changed: 299 additions & 1 deletion
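This page shows the src/llama-batch.cpp side of the commit: llama_batch_allocr gains per-token sequence sets (seq_set, seq_idx, seq_set_map) built during init(), plus three splitters — split_simple() (tokens in input order), split_equal() (one token per disjoint sequence set per step), and split_seq() (contiguous tokens of a single sequence set) — all funneling into add_ubatch(), which copies the selected tokens into owned storage.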
@@ -1,7 +1,6 @@
 #include "llama-batch.h"
 
 #include "llama-impl.h"
-#include "llama-cparams.h"
 #include "llama-vocab.h"
 #include "llama-memory.h"
 

@@ -294,6 +293,8 @@ llama_batch_allocr::llama_batch_allocr() {
     for (auto & cur : seq_cpl) {
         cur.resize(LLAMA_MAX_SEQ);
     }
+
+    seq_idx.resize(LLAMA_MAX_SEQ);
 }
 
 bool llama_batch_allocr::init(
@@ -303,6 +304,8 @@ bool llama_batch_allocr::init(
         bool embd_all) {
     clear();
 
+    split_reset();
+
     batch = batch_inp;
 
     GGML_ASSERT(batch.n_tokens > 0);
@@ -433,6 +436,21 @@ bool llama_batch_allocr::init(
         }
     }
 
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        seq_set_t cur;
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            cur.set(batch.seq_id[i][s]);
+        }
+
+        seq_set.push_back(cur);
+
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            seq_idx[batch.seq_id[i][s]].push_back(i);
+        }
+
+        seq_set_map[cur].push_back(i);
+    }
+
     if (debug > 0) {
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
         LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
@@ -532,17 +550,47 @@ bool llama_batch_allocr::init(
         }
     }
 
+    {
+        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            cur_seq_set[s].set();
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
+                cur_seq_set[seq_id] &= seq_set[i];
+
+                if (cur_seq_set[seq_id].none()) {
+                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets\n", __func__, seq_id);
+                    return false;
+                }
+            }
+        }
+    }
+
+    // TODO: check that positions are increasing
+
     return true;
 }
 
 const llama_batch & llama_batch_allocr::get_batch() const {
     return batch;
 }
 
+uint32_t llama_batch_allocr::get_n_tokens() const {
+    return pos.size();
+}
+
 uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
+    return out_ids;
+}
+
 llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
     return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
 }
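The splitters below all reason in terms of this bitset algebra. A minimal standalone illustration of the two tests they rely on (hypothetical values; LLAMA_MAX_SEQ stands in for the real constant from the library headers):

    #include <bitset>
    #include <cassert>

    int main() {
        constexpr size_t LLAMA_MAX_SEQ = 64; // assumption, for illustration only
        using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;

        seq_set_t shared; shared.set(0); shared.set(1); // a token shared by sequences 0 and 1
        seq_set_t only0;  only0.set(0);                 // a token in sequence 0 only

        // overlap test (used by split_equal): sets conflict if their intersection is non-empty
        assert(!(shared & only0).none());

        // subset test (used by split_seq): only0 is contained in shared
        assert((shared & only0) == only0);

        return 0;
    }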
@@ -551,6 +599,215 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
     return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
 }
 
+void llama_batch_allocr::split_reset() {
+    out_ids.clear();
+
+    used.clear();
+    used.resize(get_n_tokens(), false);
+
+    ubatches.clear();
+}
+
+llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
+    llama_ubatch res {
+        /*.equal_seqs   =*/ false,
+        /*.n_tokens     =*/ 0,
+        /*.n_seq_tokens =*/ 1,
+        /*.n_seqs       =*/ 0,
+
+        /*.token    =*/ nullptr,
+        /*.embd     =*/ nullptr,
+        /*.pos      =*/ nullptr,
+        /*.n_seq_id =*/ nullptr,
+        /*.seq_id   =*/ nullptr,
+        /*.output   =*/ nullptr
+    };
+
+    // find the first unused token
+    uint32_t cur_idx = 0;
+    while (cur_idx < used.size() && used[cur_idx]) {
+        ++cur_idx;
+    }
+
+    // all tokens have been consumed - return an empty ubatch
+    if (cur_idx >= used.size()) {
+        return res;
+    }
+
+    std::vector<int32_t> idxs;
+
+    while (true) {
+        res.n_tokens++;
+        res.n_seqs++;
+
+        idxs.push_back(cur_idx);
+
+        if (output[cur_idx] != 0) {
+            out_ids.push_back(cur_idx);
+        }
+
+        used[cur_idx] = true;
+
+        ++cur_idx;
+
+        if (cur_idx >= used.size()) {
+            break;
+        }
+
+        if (res.n_tokens >= n_ubatch) {
+            break;
+        }
+    }
+
+    add_ubatch(res, idxs);
+
+    return res;
+}
+
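split_simple() hands out at most n_ubatch tokens per call and an empty ubatch once everything is consumed, so a caller would presumably drain it in a loop like this (hypothetical driver; balloc and n_ubatch are stand-in names):

    // hypothetical driver loop
    while (true) {
        llama_ubatch ubatch = balloc.split_simple(n_ubatch);
        if (ubatch.n_tokens == 0) {
            break; // the input batch has been fully consumed
        }
        // ... process the ubatch ...
    }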
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+    llama_ubatch res {
+        /*.equal_seqs   =*/ true,
+        /*.n_tokens     =*/ 0,
+        /*.n_seq_tokens =*/ 0,
+        /*.n_seqs       =*/ 0,
+
+        /*.token    =*/ nullptr,
+        /*.embd     =*/ nullptr,
+        /*.pos      =*/ nullptr,
+        /*.n_seq_id =*/ nullptr,
+        /*.seq_id   =*/ nullptr,
+        /*.output   =*/ nullptr
+    };
+
+    std::vector<seq_set_t> cur_seq_set;
+
+    // determine the sequence sets participating in this ubatch
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        if (used[i]) {
+            continue;
+        }
+
+        bool add = true;
+
+        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+            // no overlap with existing sequence sets:
+            if (!(cur_seq_set[s] & seq_set[i]).none()) {
+                add = false;
+                break;
+            }
+        }
+
+        if (add) {
+            cur_seq_set.push_back(seq_set[i]);
+
+            if (cur_seq_set.size() > (size_t) n_ubatch) {
+                break;
+            }
+        }
+    }
+
+    res.n_seqs = cur_seq_set.size();
+
+    // the index of the next unused token for each participating sequence set
+    std::vector<int32_t> cur_idx(cur_seq_set.size(), 0);
+
+    for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+        while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
+            ++cur_idx[s];
+        }
+    }
+
+    std::vector<int32_t> idxs;
+
+    // TODO: reorder from 012301230123..., to 000...111...222...333...
+    while (true) {
+        bool can_expand = true;
+
+        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+            if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
+                can_expand = false;
+                break;
+            }
+        }
+
+        if (!can_expand) {
+            break;
+        }
+
+        // add one token from each sequence set
+        res.n_tokens += res.n_seqs;
+
+        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+            const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
+
+            idxs.push_back(idx);
+
+            if (output[idx] != 0) {
+                out_ids.push_back(idx);
+            }
+
+            used[idx] = true;
+
+            ++cur_idx[s];
+        }
+
+        if (res.n_tokens + res.n_seqs > n_ubatch) {
+            break;
+        }
+    }
+
+    add_ubatch(res, idxs);
+
+    return res;
+}
+
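A worked example of the expansion loop: with 3 participating sequence sets and n_ubatch = 8, the first pass takes one token per set (n_tokens = 3), the second pass another (n_tokens = 6), and the guard 6 + 3 > 8 stops a third pass, so the ubatch ends with 6 tokens interleaved as 012012 — exactly the ordering the TODO above proposes to regroup into 001122-style runs.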
+llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
+    llama_ubatch res {
+        /*.equal_seqs   =*/ true,
+        /*.n_tokens     =*/ 0,
+        /*.n_seq_tokens =*/ 0,
+        /*.n_seqs       =*/ 1,
+
+        /*.token    =*/ nullptr,
+        /*.embd     =*/ nullptr,
+        /*.pos      =*/ nullptr,
+        /*.n_seq_id =*/ nullptr,
+        /*.seq_id   =*/ nullptr,
+        /*.output   =*/ nullptr,
+    };
+
+    // find the first unused token
+    uint32_t cur_idx = 0;
+    while (cur_idx < used.size() && used[cur_idx]) {
+        ++cur_idx;
+    }
+
+    // all tokens have been consumed - return an empty ubatch
+    if (cur_idx >= used.size()) {
+        return res;
+    }
+
+    auto cur_seq_set = seq_set[cur_idx];
+
+    std::vector<int32_t> idxs;
+
+    while (true) {
+        res.n_tokens++;
+
+        idxs.push_back(cur_idx);
+
+        if (output[cur_idx] != 0) {
+            out_ids.push_back(cur_idx);
+        }
+
+        used[cur_idx] = true;
+
+        if (res.n_tokens >= n_ubatch) {
+            break;
+        }
+
+        // advance to the next unused token whose sequence set is a subset of the current one
+        do {
+            ++cur_idx;
+        } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
+
+        if (cur_idx == get_n_tokens()) {
+            break;
+        }
+
+        cur_seq_set = seq_set[cur_idx];
+    }
+
+    add_ubatch(res, idxs);
+
+    return res;
+}
+
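To illustrate the subset rule in split_seq(): for tokens with sequence sets {0,1}, {0,1}, {0}, {1} (in that order), one call starts at the first token, accepts the second ({0,1} ⊆ {0,1}) and the third ({0} ⊆ {0,1}, narrowing cur_seq_set to {0}), then skips the fourth ({1} ⊄ {0}); the {1} token is picked up by a subsequent call.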
 void llama_batch_allocr::clear() {
     n_outputs = 0;
 
@@ -567,6 +824,47 @@ void llama_batch_allocr::clear() {
     for (auto & cur : seq_cpl) {
         std::fill(cur.begin(), cur.end(), false);
     }
+
+    seq_set.clear();
+
+    for (auto & cur : seq_idx) {
+        cur.clear();
+    }
+
+    seq_set_map.clear();
+}
+
+void llama_batch_allocr::add_ubatch(llama_ubatch & res, const std::vector<int32_t> & idxs) {
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();
+
+    assert(res.n_tokens == idxs.size());
+
+    const auto n_tokens = res.n_tokens;
+
+    ubatch.token.resize(n_tokens);
+    //ubatch.embd.resize(0); // TODO
+    ubatch.pos.resize(n_tokens);
+    ubatch.n_seq_id.resize(n_tokens);
+    ubatch.seq_id.resize(n_tokens);
+    ubatch.output.resize(n_tokens);
+
+    // copy the selected tokens from the input batch into the owned storage
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        ubatch.token[i]    = batch.token[idxs[i]];
+        //ubatch.embd[i]   = batch.embd[idxs[i]]; // TODO
+        ubatch.pos[i]      = batch.pos[idxs[i]];
+        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
+        ubatch.output[i]   = batch.logits[idxs[i]];
+    }
+
+    // point the result at the owned storage
+    res.token    = ubatch.token.data();
+    //res.embd   = ubatch.embd.data(); // TODO
+    res.pos      = ubatch.pos.data();
+    res.n_seq_id = ubatch.n_seq_id.data();
+    res.seq_id   = ubatch.seq_id.data();
+    res.output   = ubatch.output.data();
 }
 
 //
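add_ubatch() copies the selected tokens into storage owned by ubatches and then repoints the returned llama_ubatch at it. From the fields touched here, each entry is presumably something like (a sketch; the struct lives in the unshown llama-batch.h changes, and the name is hypothetical):

    // hypothetical owned-storage struct behind each handed-out llama_ubatch
    struct ubatch_storage {
        std::vector<llama_token>    token;
        std::vector<llama_pos>      pos;
        std::vector<int32_t>        n_seq_id;
        std::vector<llama_seq_id *> seq_id;   // pointers still borrowed from the input batch
        std::vector<int8_t>         output;
        // embeddings are not handled yet - see the TODO comments above
    };

The returned pointers alias these vectors, so a llama_ubatch remains usable only while the allocator holds it; the next split_reset() or clear() invalidates it, and seq_id (plus output, copied from batch.logits) still depends on the caller's llama_batch staying alive.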
