
Commit dfb86c5

ggerganov authored and Minh141120 committed
batch : rework llama_batch_allocr (ggml-org#14153)
* batch : rework llama_batch_allocr
  ggml-ci
* cont : move validation inside class
  ggml-ci
* cont : move output counting to class
  ggml-ci
* cont : minor
  ggml-ci
* batch : add TODOs
  ggml-ci
1 parent 40eefac commit dfb86c5
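In practice the rework means llama_context::encode()/decode() no longer validate token ids, check seq_id ranges, or count outputs themselves; they hand the incoming batch to the allocator and read the results back. A condensed sketch of the new call pattern, using the names introduced in the diffs below (error handling and surrounding context trimmed, so this is illustrative rather than verbatim commit code):

    // Illustrative only - condensed from the llama-context.cpp changes in this commit.
    // p0 is -1 when the caller already supplied positions, otherwise the next position to use.
    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
        return -1; // token ids and seq_id ranges are now validated inside init()
    }

    const llama_batch & batch     = batch_allocr->get_batch();     // missing fields filled with defaults
    const uint32_t      n_outputs = batch_allocr->get_n_outputs(); // outputs counted by the allocator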

6 files changed: +161 -112 lines changed

src/llama-batch.cpp

Lines changed: 64 additions & 26 deletions
@@ -1,8 +1,8 @@
 #include "llama-batch.h"
 
 #include "llama-impl.h"
+#include "llama-cparams.h"
 #include "llama-vocab.h"
-#include "llama-memory.h"
 
 #include <cassert>
 #include <cstring>
@@ -611,37 +611,48 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
         }
     }
 
-    ss_seq_id_unq << "]";
-    ss_seq_idx << "]";
-
-    LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
-    LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
-    LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
-    LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
-    LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
-    LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
-    LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
-    LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
-    LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
-
-    if (debug > 1) {
-        int seq_id_max = 0;
-        for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-            for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
-                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
-                    seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
-                }
+    // keep shared prompts first at the end, then sort by length descending.
+    std::sort(seq.begin(), seq.end(),
+        [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
+            if (a.n_seq_id == b.n_seq_id) {
+                return a.length > b.length;
             }
+            return a.n_seq_id < b.n_seq_id;
         }
-        ++seq_id_max;
-
-        LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
-        for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-            std::vector<int8_t> seq_id(seq_id_max);
+    );
+}
 
 llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
     batch = in_batch;
     GGML_ASSERT(batch.n_tokens > 0);
+
+    if (!batch.pos) {
+        if (batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return false;
+        }
+    }
+
+    if (batch.token) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return false;
+            }
+        }
+    }
+
+    if (batch.seq_id) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                    return false;
+                }
+            }
+        }
+    }
+
     if (!batch.pos) {
         assert(p0 >= 0);
         pos.resize(batch.n_tokens);
@@ -650,13 +661,15 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.pos = pos.data();
     }
+
     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
             n_seq_id[i] = seq_id_0.size();
         }
         batch.n_seq_id = n_seq_id.data();
     }
+
     if (!batch.seq_id) {
         seq_id.resize(batch.n_tokens + 1);
         seq_id[batch.n_tokens] = NULL;
@@ -665,12 +678,37 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.seq_id = seq_id.data();
     }
+
     if (!batch.logits) {
         // by default return the output only for the last token
         output.resize(batch.n_tokens);
         output[output.size() - 1] = true;
         batch.logits = output.data();
     }
+
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        n_outputs += batch.logits[i] != 0;
+    }
+
+    return true;
+}
+
+const llama_batch & llama_batch_allocr::get_batch() const {
+    return batch;
+}
+
+uint32_t llama_batch_allocr::get_n_outputs() const {
+    return n_outputs;
+}
+
+void llama_batch_allocr::clear() {
+    n_outputs = 0;
+
+    batch = {};
+    pos.clear();
+    n_seq_id.clear();
+    seq_id.clear();
+    output.clear();
 }
 
 //

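When the caller leaves pos, n_seq_id, seq_id or logits as NULL, the code above fills them in: positions derived from p0, a default sequence id of 0, and an output flag only on the last token, after which the requested outputs are counted. The following standalone sketch mirrors that defaulting logic outside of llama.cpp; the struct and helper names are hypothetical, and the position formula (p0 + i) is an assumption since the fill loop body is not part of this hunk.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Mirrors (approximately) what llama_batch_allocr does for a batch that only has tokens:
    // derive positions from p0, default every token to sequence 0, request output for the
    // last token only, and count the outputs. Hypothetical helper, not code from the commit.
    struct batch_defaults {
        std::vector<int32_t> pos;      // one position per token
        std::vector<int32_t> seq_id;   // one (default) sequence id per token
        std::vector<int8_t>  output;   // 1 if the token should produce logits
        uint32_t             n_outputs = 0;
    };

    static batch_defaults fill_defaults(int32_t n_tokens, int32_t p0) {
        batch_defaults d;
        d.pos.resize(n_tokens);
        d.seq_id.assign(n_tokens, 0);          // default sequence id 0
        d.output.assign(n_tokens, 0);
        for (int32_t i = 0; i < n_tokens; ++i) {
            d.pos[i] = p0 + i;                 // assumed: positions continue from p0
        }
        d.output[n_tokens - 1] = 1;            // by default only the last token is output
        for (int32_t i = 0; i < n_tokens; ++i) {
            d.n_outputs += d.output[i] != 0;
        }
        return d;
    }

    int main() {
        const batch_defaults d = fill_defaults(/*n_tokens=*/4, /*p0=*/10);
        printf("pos[3] = %d, n_outputs = %u\n", d.pos[3], d.n_outputs); // pos[3] = 13, n_outputs = 1
        return 0;
    }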
src/llama-batch.h

Lines changed: 25 additions & 17 deletions
@@ -21,19 +21,12 @@ struct llama_ubatch {
     uint32_t n_seqs;     // sequence sets in the ubatch
     uint32_t n_seqs_unq; // unique sequence ids in the ubatch
 
-    // seq_id_unq: unique sequence ids in the ubatch
-    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
-    //             used for extracting sequence pooled embeddings
-
-    //                          // size               | idx | val
-    llama_token  *  token;      // [n_tokens]         | i   | id, token
-    float        *  embd;       // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;        // [n_tokens]         | i   | pos
-    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
-    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
-    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
-    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
-    int8_t       *  output;     // [n_tokens]         | i   | -
+    llama_token  *  token;    // [n_tokens]
+    float        *  embd;     // [n_embd, n_tokens]
+    llama_pos    *  pos;      // [n_tokens]
+    int32_t      *  n_seq_id; // [n_seqs]
+    llama_seq_id ** seq_id;   // [n_seqs]
+    int8_t       *  output;   // [n_tokens]
 };
 
 struct llama_sbatch_seq {
@@ -91,14 +84,29 @@ struct llama_sbatch {
     llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
-    // used[i] indicates if token i has already been used in a previous ubatch
-    std::vector<bool> used;
+// temporary allocate memory for the input batch if needed
+class llama_batch_allocr {
+public:
+    llama_batch_allocr();
+
+    // optionally fulfill the batch returned by llama_batch_get_one
+    bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
+
+    const llama_batch & get_batch() const;
+
+    uint32_t get_n_outputs() const;
+
+private:
+    void clear();
+
+    llama_batch batch;
+
+    uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
+
     std::vector<llama_pos> pos;
     std::vector<int32_t> n_seq_id;
     std::vector<llama_seq_id *> seq_id;
     std::vector<int8_t> output;
-
-    int debug;
 };

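With seq_id_unq and seq_idx dropped from llama_ubatch, consumers go back to walking seq_id[i][s] directly, which is exactly what the TODO-marked loops in llama-context.cpp below do. A hedged sketch of that iteration pattern, assuming a ubatch produced by the usual sbatch splitting; only the loop shape is illustrated, and the per-sequence handling is left as a comment.

    #include "llama-batch.h"

    // Sketch: visit the output tokens of a llama_ubatch using only the fields kept by this commit.
    // Note: the commit itself flags this style of indexing with "TODO: fix indexing [UBATCH_IDX]".
    static void visit_outputs(const llama_ubatch & ubatch) {
        for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
            if (ubatch.output && ubatch.output[i] == 0) {
                continue; // this token does not produce logits/embeddings
            }
            for (int32_t s = 0; s < ubatch.n_seq_id[i]; ++s) {
                const llama_seq_id sid = ubatch.seq_id[i][s];
                // ... per-sequence handling keyed by sid, e.g. pooled-embedding extraction ...
                (void) sid;
            }
        }
    }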
src/llama-context.cpp

Lines changed: 36 additions & 45 deletions
@@ -20,7 +20,7 @@ llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
     model(model),
-    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
+    batch_allocr(std::make_unique<llama_batch_allocr>()) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
     t_start_us = model.t_start_us;
@@ -722,26 +722,23 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
 }
 
 int llama_context::encode(const llama_batch & batch_inp) {
-    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
-
     if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
-    const auto & hparams = model.hparams;
-
-    const int64_t n_embd = hparams.n_embd;
-
+    // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
-    const uint32_t n_tokens = balloc->get_n_tokens();
+    const llama_batch & batch = batch_allocr->get_batch();
 
-    const llama_ubatch ubatch = balloc->split_simple(n_tokens);
+    const uint32_t n_tokens = batch.n_tokens;
+
+    const int64_t n_embd = hparams.n_embd;
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
@@ -755,6 +752,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_queued_tokens += n_tokens;
 
+    const auto & hparams = model.hparams;
+
     const int64_t n_embd = hparams.n_embd;
 
     llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
@@ -824,6 +823,12 @@ int llama_context::encode(const llama_batch & batch_inp) {
         const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
         const int32_t      seq_idx = ubatch.seq_idx[seq_id];
 
+        // TODO: fix indexing [UBATCH_IDX]
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                continue;
+            }
         embd_seq_out[seq_id].resize(n_embd);
         ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
     }
@@ -835,10 +840,12 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     const uint32_t n_cls_out = hparams.n_cls_out;
 
-    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
-        const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
-        const int32_t      seq_idx = ubatch.seq_idx[seq_id];
-
+    // TODO: fix indexing [UBATCH_IDX]
+    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+            continue;
+        }
         embd_seq_out[seq_id].resize(n_cls_out);
         ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
     }
@@ -868,15 +875,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const auto & batch = balloc->get_batch();
 
     // remember the sequence ids used during the encoding - needed for cross attention later
-    // TODO: the seuqence indexing here is likely not correct in the general case
-    //       probably works only for split_simple
     cross.seq_ids_enc.resize(n_tokens);
     for (uint32_t i = 0; i < n_tokens; i++) {
         cross.seq_ids_enc[i].clear();
-
         for (int s = 0; s < batch.n_seq_id[i]; s++) {
-            const llama_seq_id seq_id = batch.seq_id[i][s];
-
+            llama_seq_id seq_id = batch.seq_id[i][s];
             cross.seq_ids_enc[i].insert(seq_id);
         }
     }
@@ -886,57 +889,44 @@ int llama_context::encode(const llama_batch & batch_inp) {
 }
 
 int llama_context::decode(const llama_batch & batch_inp) {
-    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
-
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(batch_inp);
+        return encode(batch_inp);
     }
 
     if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporary allocate memory for the input batch if needed
+    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
+    }
+
+    const llama_batch & batch = batch_allocr->get_batch();
+
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
     const int32_t n_vocab = vocab.n_tokens();
+    const int64_t n_embd  = hparams.n_embd;
 
-    const int64_t n_tokens_all = batch.n_tokens;
-    const int64_t n_embd       = hparams.n_embd;
+    const uint32_t n_tokens_all = batch.n_tokens;
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // TODO: move the validation to the llama_batch_allocr
-    if (batch.token) {
-        for (int64_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                return -1;
-            }
-        }
-    }
-
     // this indicates we are doing pooled embedding
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
-    int64_t n_outputs_all = 0;
-
-    // count outputs
-    for (uint32_t i = 0; i < n_tokens_all; ++i) {
-        n_outputs_all += batch.logits[i] != 0;
-    }
+    const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
     if (embd_pooled) {
         // require that all tokens are output
        if (n_outputs_all != n_tokens_all) {
-            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
                 __func__, n_outputs_all, n_tokens_all);
            return -1;
        }
10441034
pos_min[s] = std::numeric_limits<llama_pos>::max();
10451035
}
10461036

1037+
// TODO: fix sequence indexing
10471038
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
10481039
const auto & seq_id = ubatch.seq_id[i][0];
10491040

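For users of the public API the behaviour is unchanged: a batch produced by llama_batch_get_one() carries only the token array, and the reworked allocator supplies positions (continuing from memory->seq_pos_max(0) + 1 in decode), a default sequence id, and the last-token output flag. A hedged usage sketch, assuming an already-initialized context and the two-argument llama_batch_get_one() found in current llama.h:

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    // Sketch only: ctx is assumed to be a valid llama_context, tokens a tokenized prompt.
    int decode_prompt(llama_context * ctx, std::vector<llama_token> & tokens) {
        // pos, n_seq_id, seq_id and logits are left NULL here; llama_batch_allocr::init()
        // fills them in and counts the single default output (the last token).
        llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

        if (llama_decode(ctx, batch) != 0) {
            fprintf(stderr, "llama_decode failed\n");
            return -1;
        }
        return 0;
    }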
src/llama-context.h

Lines changed: 1 addition & 1 deletion
@@ -247,7 +247,7 @@ struct llama_context {
     std::map<llama_seq_id, std::vector<float>> embd_seq;
 
     // reuse the batch_allocr to avoid unnecessary memory allocations
-    std::unique_ptr<llama_batch_allocr> balloc;
+    std::unique_ptr<llama_batch_allocr> batch_allocr;
 
     uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
