-
Notifications
You must be signed in to change notification settings - Fork 12.4k
batch : rework llama_batch_allocr #14153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,6 +1,7 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include "llama-context.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include "llama-impl.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include "llama-batch.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include "llama-io.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include "llama-memory.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include "llama-mmap.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -18,7 +19,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||
llama_context::llama_context( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const llama_model & model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
llama_context_params params) : | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
model(model) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
model(model), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_allocr(std::make_unique<llama_batch_allocr>()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
t_start_us = model.t_start_us; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -494,7 +496,7 @@ float * llama_context::get_logits() { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
float * llama_context::get_logits_ith(int32_t i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
int32_t j = -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
int64_t j = -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
try { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (logits == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (j >= n_outputs) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// This should not happen | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
return logits + j*model.vocab.n_tokens(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
float * llama_context::get_embeddings_ith(int32_t i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
int32_t j = -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
int64_t j = -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
try { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (embd == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (j >= n_outputs) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// This should not happen | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
return embd + j*model.hparams.n_embd; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -727,18 +729,19 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// temporary allocate memory for the input batch if needed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// note: during encode, we always pass the full sequence starting from pos = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_allocr->init(inp_batch, inp_batch.pos ? -1 : 0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
const llama_batch & batch = batch_allocr.batch; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const int32_t n_tokens = batch.n_tokens; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const llama_batch & batch = batch_allocr->get_batch(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
const uint32_t n_tokens = batch.n_tokens; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
const auto & hparams = model.hparams; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: move the validation to the llama_batch_allocr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (batch.token) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int32_t i = 0; i < n_tokens; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < n_tokens; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -775,7 +778,7 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
return -2; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int32_t i = 0; i < n_tokens; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < n_tokens; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
output_ids[i] = i; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -831,7 +834,8 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int32_t i = 0; i < n_tokens; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: fix sequence indexing | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < n_tokens; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const llama_seq_id seq_id = ubatch.seq_id[i][0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+826
to
827
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @compilade Regarding this comment from earlier, how does this sequence traversal work correctly when the ubatch is created with AFAIU the original meaning of llama.cpp/src/llama-context.cpp Lines 1146 to 1152 in f164ba9
I am planning to rework this in some way, so any suggestions how to improve this logic are welcome. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ggerganov Line 141 in a592c13
This makes traversal which would work correctly with I'm not sure how to make it more obvious while still sharing the same traversal code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, I understand that this traversal over the tokens is correct for both split strategies: llama.cpp/src/llama-kv-cache-unified.cpp Lines 816 to 822 in 4c07964
However, if I want to traverse over the unique sequence ids in the ubatch, or traverse over all sequence ids to which a token in the ubatch is assigned, there is no way to do it correctly for both splits. Is this correct? For example, in the snippet above, if I wanted to get the list of all sequence ids of token There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes, traversing unique
This is easier, though, and possible by traversing llama.cpp/src/llama-kv-cache-recurrent.cpp Lines 446 to 449 in 26ff368
In that snippet, for (uint32_t s = 0; s < n_seqs; ++s) {
for (uint32_t j = 0; j < n_seq_tokens; ++j) {
const uint32_t idx = s*n_seq_tokens + j;
const llama_pos p1 = ubatch->pos[idx];
for (uint32_t k = 0; k < ubatch.n_seq_id[s]; ++k) {
const llama_seq_id seq_id = ubatch->seq_id[s][k]; Although depending on what you need it's also possible to swap the two inner loops: for (uint32_t s = 0; s < n_seqs; ++s) {
for (uint32_t k = 0; k < ubatch.n_seq_id[s]; ++k) {
const llama_seq_id seq_id = ubatch->seq_id[s][k];
for (uint32_t j = 0; j < n_seq_tokens; ++j) {
const uint32_t idx = s*n_seq_tokens + j;
const llama_pos p1 = ubatch->pos[idx]; In this situation, you would not need to check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, thank you. I think I understand now. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -881,7 +885,7 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: the seuqence indexing here is likely not correct in the general case | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// probably works only for split_simple | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
cross.seq_ids_enc.resize(n_tokens); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int32_t i = 0; i < n_tokens; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < n_tokens; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
cross.seq_ids_enc[i].clear(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int s = 0; s < ubatch.n_seq_id[i]; s++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
llama_seq_id seq_id = ubatch.seq_id[i][s]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -912,30 +916,30 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// temporary allocate memory for the input batch if needed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_allocr->init(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
const llama_batch & batch = batch_allocr.batch; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const llama_batch & batch = batch_allocr->get_batch(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
const auto & vocab = model.vocab; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const auto & hparams = model.hparams; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
const int32_t n_vocab = vocab.n_tokens(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t n_embd = hparams.n_embd; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t n_tokens_all = batch.n_tokens; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const int64_t n_embd = hparams.n_embd; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const uint32_t n_tokens_all = batch.n_tokens; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: move the validation to the llama_batch_allocr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (batch.token) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int64_t i = 0; i < n_tokens_all; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < n_tokens_all; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -944,7 +948,7 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
// this indicates we are doing pooled embedding | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
int64_t n_outputs_all = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
uint32_t n_outputs_all = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// count outputs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < n_tokens_all; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -954,7 +958,7 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
if (embd_pooled) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// require that all tokens are output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (n_outputs_all != n_tokens_all) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
__func__, n_outputs_all, n_tokens_all); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1024,7 +1028,7 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// reserve output buffer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (output_reserve(n_outputs_all) < n_outputs_all) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return -2; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1063,6 +1067,7 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
pos_min[s] = std::numeric_limits<llama_pos>::max(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: fix sequence indexing | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const auto & seq_id = ubatch.seq_id[i][0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1176,14 +1181,14 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
n_outputs = n_outputs_all; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// set output mappings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (n_outputs > 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
bool sorted_output = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto & out_ids = mstate->out_ids(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
GGML_ASSERT(out_ids.size() == (size_t) n_outputs); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int64_t i = 0; i < n_outputs_all; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int64_t i = 0; i < n_outputs; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
int64_t out_id = out_ids[i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
output_ids[out_id] = i; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (out_id != i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1195,20 +1200,22 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
// note: this is mostly relevant for recurrent models atm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!sorted_output) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const uint32_t n_vocab = model.vocab.n_tokens(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const uint32_t n_embd = model.hparams.n_embd; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const uint64_t n_embd = model.hparams.n_embd; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
GGML_ASSERT((size_t) n_outputs == out_ids.size()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: is there something more efficient which also minimizes swaps? | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int32_t i = 0; i < n_outputs - 1; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
int32_t j_min = i; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int32_t j = i + 1; j < n_outputs; ++j) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < n_outputs - 1; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
uint32_t j_min = i; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t j = i + 1; j < n_outputs; ++j) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (out_ids[j] < out_ids[j_min]) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
j_min = j; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (j_min == i) { continue; } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (j_min == i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::swap(out_ids[i], out_ids[j_min]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (logits_size > 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t k = 0; k < n_vocab; k++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1221,8 +1228,10 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::fill(output_ids.begin(), output_ids.end(), -1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int32_t i = 0; i < n_outputs; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < n_outputs; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
output_ids[out_ids[i]] = i; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1242,7 +1251,7 @@ int llama_context::decode(llama_batch & inp_batch) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
// output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
int32_t llama_context::output_reserve(int32_t n_outputs) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
uint32_t llama_context::output_reserve(int32_t n_outputs) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const auto & hparams = model.hparams; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
const auto & vocab = model.vocab; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1308,8 +1317,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
// set all ids as invalid (negative) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::fill(output_ids.begin(), output_ids.end(), -1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
this->n_outputs = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
this->n_outputs_max = n_outputs_max; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
this->n_outputs = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
return n_outputs_max; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1800,14 +1808,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::vector<int32_t> w_output_pos; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
GGML_ASSERT(n_outputs <= n_outputs_max); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
w_output_pos.resize(n_outputs); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// build a more compact representation of the output ids | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (size_t i = 0; i < n_batch(); ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// map an output id to a position in the batch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
int32_t pos = output_ids[i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
int64_t pos = output_ids[i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (pos >= 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
GGML_ASSERT(pos < n_outputs); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
w_output_pos[pos] = i; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -2082,7 +2088,7 @@ void llama_context::opt_epoch_iter( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
embd_seq.clear(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
int64_t n_outputs_all = n_tokens_all; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
uint32_t n_outputs_all = n_tokens_all; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -2092,7 +2098,7 @@ void llama_context::opt_epoch_iter( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// reserve output buffer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (output_reserve(n_outputs_all) < n_outputs_all) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
GGML_ABORT("TODO: handle this error"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Decided against these TODOs because multiple sequences per input token actually has some useful properties that cannot be achieved otherwise (for example see the hellaswag usage). Instead, will add logic to guarantee that the provided ids are valid, utilizing the memory's
seq_pos_min()
andseq_pos_max()
methods.