Commit 3fd6eb4

ggerganov authored and Minh141120 committed
context : simplify output counting logic during decode (ggml-org#14142)
* batch : remove logits_all flag

ggml-ci

* context : simplify output counting logic during decode

ggml-ci

* cont : fix comments
1 parent 3724d37 commit 3fd6eb4

3 files changed: +66 -47 lines
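The change described by the commit message can be summarized as: instead of a batch-wide logits_all flag, decode now derives the number of outputs directly from the per-token output flags carried by the batch. The following is a minimal standalone sketch of that counting rule, not the actual llama.cpp API (the helper name count_outputs is ours):

    // Minimal sketch: the number of outputs is just the number of tokens whose
    // per-token flag is set; a batch-wide "output everything" flag is not needed.
    #include <cstdint>

    static int64_t count_outputs(const int8_t * output_flags, int64_t n_tokens) {
        int64_t n_outputs = 0;
        for (int64_t i = 0; i < n_tokens; ++i) {
            n_outputs += output_flags[i] != 0; // each non-zero flag requests one output
        }
        return n_outputs;
    }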

src/llama-batch.cpp (30 additions, 23 deletions)

@@ -639,30 +639,37 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
         std::vector<int8_t> seq_id(seq_id_max);
 
-        for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
-            seq_id[ubatch.seq_id[i][s]] = 1;
-        }
-
-        std::stringstream ss;
-        for (int s = 0; s < seq_id_max; ++s) {
-            if (seq_id[s]) {
-                ss << s%10;
-            } else {
-                ss << ".";
-            }
-        }
-
-        if (ubatch.token) {
-            LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
-                    __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
-                    ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
-        } else {
-            LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
-                    __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
-        }
-    }
-    LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
+    batch = in_batch;
+    GGML_ASSERT(batch.n_tokens > 0);
+    if (!batch.pos) {
+        assert(p0 >= 0);
+        pos.resize(batch.n_tokens);
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            pos[i] = p0 + i;
         }
+        batch.pos = pos.data();
+    }
+    if (!batch.n_seq_id) {
+        n_seq_id.resize(batch.n_tokens);
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            n_seq_id[i] = seq_id_0.size();
+        }
+        batch.n_seq_id = n_seq_id.data();
+    }
+    if (!batch.seq_id) {
+        seq_id.resize(batch.n_tokens + 1);
+        seq_id[batch.n_tokens] = NULL;
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            seq_id[i] = seq_id_0.data();
+        }
+        batch.seq_id = seq_id.data();
+    }
+    if (!batch.logits) {
+        // by default return the output only for the last token
+        output.resize(batch.n_tokens);
+        output[output.size() - 1] = true;
+        batch.logits = output.data();
     }
 }
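The constructor added above normalizes a caller-supplied batch by filling in any fields left null: consecutive positions starting at p0, a single default sequence id 0 for every token, and output requested only for the last token. A simplified standalone sketch of those defaulting rules (the FilledDefaults struct and fill_defaults helper are illustrative, not part of llama.cpp):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Illustrative only: mirrors the defaults filled in by the constructor above.
    struct FilledDefaults {
        std::vector<int32_t> pos;      // p0, p0+1, ..., p0+n-1
        std::vector<int32_t> n_seq_id; // one sequence per token
        std::vector<int32_t> seq_id;   // default sequence id 0 for every token
        std::vector<int8_t>  output;   // only the last token produces output
    };

    static FilledDefaults fill_defaults(int32_t n_tokens, int32_t p0) {
        assert(n_tokens > 0 && p0 >= 0);
        FilledDefaults d;
        d.pos.resize(n_tokens);
        for (int32_t i = 0; i < n_tokens; i++) {
            d.pos[i] = p0 + i;
        }
        d.n_seq_id.assign(n_tokens, 1);
        d.seq_id.assign(n_tokens, 0);
        d.output.assign(n_tokens, 0);
        d.output[n_tokens - 1] = 1; // by default, only the last token's output is returned
        return d;
    }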

src/llama-batch.h (5 additions, 14 deletions)

@@ -94,20 +94,11 @@ struct llama_sbatch {
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token> token;
-        std::vector<float> embd;
-        std::vector<llama_pos> pos;
-        std::vector<int32_t> n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id> seq_id_unq;
-        std::vector<int32_t> seq_idx;
-        std::vector<int8_t> output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
+    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
+    std::vector<llama_pos> pos;
+    std::vector<int32_t> n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t> output;
 
     int debug;
 };

src/llama-context.cpp (31 additions, 10 deletions)

@@ -902,23 +902,41 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;
 
     const int32_t n_vocab = vocab.n_tokens();
-    const int64_t n_embd = hparams.n_embd;
 
-    // when computing embeddings, all tokens are output
-    const bool output_all = cparams.embeddings;
+    const int64_t n_tokens_all = batch.n_tokens;
+    const int64_t n_embd       = hparams.n_embd;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
-        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
-        return -1;
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    // TODO: move the validation to the llama_batch_allocr
+    if (batch.token) {
+        for (int64_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                return -1;
+            }
+        }
     }
 
-    const uint32_t n_tokens_all = balloc->get_n_tokens();
-    const uint32_t n_outputs_all = balloc->get_n_outputs();
+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    int64_t n_outputs_all = 0;
 
-    if (output_all) {
+    // count outputs
+    for (uint32_t i = 0; i < n_tokens_all; ++i) {
+        n_outputs_all += batch.logits[i] != 0;
+    }
+
+    if (embd_pooled) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
-            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
                 __func__, n_outputs_all, n_tokens_all);
             return -1;
         }

@@ -2045,6 +2063,9 @@ void llama_context::opt_epoch_iter(
 
     n_queued_tokens += n_tokens_all;
 
+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
     embd_seq.clear();
 
     uint32_t n_outputs_all = n_tokens_all;
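With the counting loop above, whatever the caller marks in the per-token logits array is exactly what decode produces, and a pooled-embedding batch is rejected unless every token is marked. A hedged usage sketch (only the fact that batch.logits is an int8_t-per-token flag array comes from the diff; the variable names are illustrative):

    #include <cstdint>
    #include <vector>

    int main() {
        const int n_tokens = 8;

        // Generation-style batch: only the last token needs an output.
        std::vector<int8_t> flags_last(n_tokens, 0);
        flags_last.back() = 1; // decode() would count exactly one output

        // Pooled-embedding batch: every token must be flagged, otherwise the
        // check above fails and decode() returns -1.
        std::vector<int8_t> flags_all(n_tokens, 1);

        (void) flags_last;
        (void) flags_all;
        return 0;
    }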
