batch : rework llama_batch_allocr #14153

Merged: 5 commits, Jun 13, 2025

Changes from 1 commit
27 changes: 25 additions & 2 deletions src/llama-batch.cpp
@@ -279,9 +279,15 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
);
}

llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
llama_batch_allocr::llama_batch_allocr() = default;

bool llama_batch_allocr::init(struct llama_batch in_batch, llama_pos p0) {
GGML_ASSERT(in_batch.n_tokens > 0);

clear();

batch = in_batch;
GGML_ASSERT(batch.n_tokens > 0);

if (!batch.pos) {
assert(p0 >= 0);
pos.resize(batch.n_tokens);
@@ -290,13 +296,15 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
}
batch.pos = pos.data();
}

if (!batch.n_seq_id) {
n_seq_id.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
n_seq_id[i] = seq_id_0.size();
}
batch.n_seq_id = n_seq_id.data();
}

if (!batch.seq_id) {
seq_id.resize(batch.n_tokens + 1);
seq_id[batch.n_tokens] = NULL;
@@ -305,12 +313,27 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
}
batch.seq_id = seq_id.data();
}

if (!batch.logits) {
// by default return the output only for the last token
output.resize(batch.n_tokens);
output[output.size() - 1] = true;
batch.logits = output.data();
}

return true;
}

const llama_batch & llama_batch_allocr::get_batch() const {
return batch;
}

void llama_batch_allocr::clear() {
batch = {};
pos.clear();
n_seq_id.clear();
seq_id.clear();
output.clear();
}

//
22 changes: 15 additions & 7 deletions src/llama-batch.h
@@ -18,8 +18,8 @@ struct llama_ubatch {
llama_token * token; // [n_tokens]
float * embd; // [n_embd, n_tokens]
llama_pos * pos; // [n_tokens]
int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
Comment on lines -21 to -22

ggerganov (Member, Author) commented:
Decided against these TODOs because multiple sequences per input token actually has some useful properties that cannot be achieved otherwise (for example see the hellaswag usage). Instead, will add logic to guarantee that the provided ids are valid, utilizing the memory's seq_pos_min() and seq_pos_max() methods.
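
For illustration, a rough sketch of what that validation could look like (hypothetical, not part of this PR): the callback stands in for the memory's seq_pos_max() as used in llama_context::decode(), and the batch is assumed to already have its pos/n_seq_id/seq_id arrays filled in.

#include <functional>

#include "llama.h" // llama_batch, llama_pos, llama_seq_id

// hypothetical helper - checks that every token extends its sequences,
// i.e. it is not placed before the current end of any sequence it belongs to
static bool batch_positions_are_valid(
        const llama_batch & batch,
        const std::function<llama_pos(llama_seq_id)> & seq_pos_max) {
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
            const llama_seq_id seq_id = batch.seq_id[i][s];

            // seq_pos_max(seq_id) is assumed to return -1 for an empty sequence
            if (batch.pos[i] < seq_pos_max(seq_id) + 1) {
                return false;
            }
        }
    }
    return true;
}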

int32_t * n_seq_id; // [n_seqs]
llama_seq_id ** seq_id; // [n_seqs]
int8_t * output; // [n_tokens]
};

@@ -78,15 +78,23 @@ struct llama_sbatch {
};

// temporary allocate memory for the input batch if needed
struct llama_batch_allocr {
struct llama_batch batch;
class llama_batch_allocr {
public:
llama_batch_allocr();

// optionally fulfill the batch returned by llama_batch_get_one
bool init(llama_batch in_batch, llama_pos p0);

const llama_batch & get_batch() const;

private:
void clear();

llama_batch batch;

std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> output;

// optionally fulfill the batch returned by llama_batch_get_one
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
};
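
For reference, a minimal usage sketch of the reworked interface (illustrative only; it mirrors the call sites in llama-context.cpp below, and the p0 fallback of 0 plus the error handling are assumptions):

#include "llama-batch.h"

// illustrative call site - the allocator is reused across calls and owns any
// temporary pos/n_seq_id/seq_id/output arrays it had to create
static int process(llama_batch_allocr & balloc, llama_batch & inp_batch) {
    // p0 = -1 means the caller provided positions, otherwise start from 0
    if (!balloc.init(inp_batch, inp_batch.pos ? -1 : 0)) {
        return -1;
    }

    const llama_batch & batch = balloc.get_batch();

    // ... consume `batch` as before ...
    (void) batch;

    return 0;
}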
84 changes: 45 additions & 39 deletions src/llama-context.cpp
@@ -1,6 +1,7 @@
#include "llama-context.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
#include "llama-memory.h"
#include "llama-mmap.h"
@@ -18,7 +19,8 @@
llama_context::llama_context(
const llama_model & model,
llama_context_params params) :
model(model) {
model(model),
batch_allocr(std::make_unique<llama_batch_allocr>()) {
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

t_start_us = model.t_start_us;
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
}

float * llama_context::get_logits_ith(int32_t i) {
int32_t j = -1;
int64_t j = -1;

try {
if (logits == nullptr) {
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
}
if (j >= n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}

return logits + j*model.vocab.n_tokens();
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
}

float * llama_context::get_embeddings_ith(int32_t i) {
int32_t j = -1;
int64_t j = -1;

try {
if (embd == nullptr) {
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
}
if (j >= n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}

return embd + j*model.hparams.n_embd;
@@ -727,18 +729,19 @@ int llama_context::encode(llama_batch & inp_batch) {

// temporary allocate memory for the input batch if needed
// note: during encode, we always pass the full sequence starting from pos = 0
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
batch_allocr->init(inp_batch, inp_batch.pos ? -1 : 0);

const llama_batch & batch = batch_allocr.batch;
const int32_t n_tokens = batch.n_tokens;
const llama_batch & batch = batch_allocr->get_batch();

const uint32_t n_tokens = batch.n_tokens;

const auto & hparams = model.hparams;

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

// TODO: move the validation to the llama_batch_allocr
if (batch.token) {
for (int32_t i = 0; i < n_tokens; ++i) {
for (uint32_t i = 0; i < n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
@@ -775,7 +778,7 @@ int llama_context::encode(llama_batch & inp_batch) {
return -2;
};

for (int32_t i = 0; i < n_tokens; ++i) {
for (uint32_t i = 0; i < n_tokens; ++i) {
output_ids[i] = i;
}

@@ -831,7 +834,8 @@ int llama_context::encode(llama_batch & inp_batch) {

GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

for (int32_t i = 0; i < n_tokens; i++) {
// TODO: fix sequence indexing
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = ubatch.seq_id[i][0];
Comment on lines +826 to 827

ggerganov (Member, Author) commented on Jun 12, 2025:
@compilade Regarding this comment from earlier, how does this sequence traversal work correctly when the ubatch is created with split_simple()?

AFAIU the original meaning of ubatch.seq_id[i][j] was "the jth sequence of the ith token". With split_equal(), this now changes to "the ith sequence and j == 0". What is not clear to me is: if I used split_simple(), how could the sequence traversal be correct?

for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
    const llama_seq_id seq_id = ubatch.seq_id[s][0];
    if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
        continue;
    }
    embd_seq_out[seq_id].resize(n_embd);

I am planning to rework this in some way, so any suggestions how to improve this logic are welcome.

compilade (Collaborator) replied on Jun 12, 2025:
@ggerganov
With split_simple(), an invariant is that ubatch.n_seqs == n_tokens, and ubatch.n_seq_tokens == 1, because the sequences are not aggregated.

ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits

This makes any traversal that would work correctly with split_equal also correct with split_simple, even though the seq_ids are definitely repeated (when ubatch.equal_seqs == false, ubatch.n_seqs doesn't really map to distinct sequences).

I'm not sure how to make it more obvious while still sharing the same traversal code.

ggerganov (Member, Author) replied:
Thanks, I understand that this traversal over the tokens is correct for both split strategies:

for (uint32_t s = 0; s < n_seqs; ++s) {
    const llama_seq_id seq_id = ubatch->seq_id[s][0];
    for (uint32_t j = 0; j < n_seq_tokens; ++j) {
        const uint32_t idx = s*n_seq_tokens + j;
        const llama_pos p1 = ubatch->pos[idx];

However, if I want to traverse over the unique sequence ids in the ubatch, or traverse over all sequence ids to which a token in the ubatch is assigned, there is no way to do it correctly for both splits. Is this correct?

For example, in the snippet above, if I wanted to get the list of all sequence ids of token idx, there is no way to do it without checking the ubatch.equal_seqs. Correct?

compilade (Collaborator) replied:
> if I want to traverse over the unique sequence ids in the ubatch

Yes, traversing unique seq_ids with simple splits (when ubatch.equal_seqs == false) is a bit more complicated, because they are not aggregated (simple splits are plain slices of the user-provided batch).

> traverse over all sequence ids to which a token in the ubatch is assigned

This is easier, though, and possible by traversing ubatch.seq_id[s][_] with ubatch.n_seq_id[s]. For example:

for (uint32_t s = 0; s < n_seqs; ++s) {
    const uint32_t n_seq_id = ubatch.n_seq_id[s];
    for (uint32_t j = 0; j < n_seq_id; ++j) {
        const llama_seq_id seq_id = ubatch.seq_id[s][j];


> for (uint32_t s = 0; s < n_seqs; ++s) {
>     const llama_seq_id seq_id = ubatch->seq_id[s][0];
>     for (uint32_t j = 0; j < n_seq_tokens; ++j) {
>         const uint32_t idx = s*n_seq_tokens + j;
>         const llama_pos p1 = ubatch->pos[idx];
>
> [...]
>
> For example, in the snippet above, if I wanted to get the list of all sequence ids of token idx

In that snippet, seq_id would need to be defined later:

 for (uint32_t s = 0; s < n_seqs; ++s) {

     for (uint32_t j = 0; j < n_seq_tokens; ++j) {
         const uint32_t idx = s*n_seq_tokens + j;

         const llama_pos p1 = ubatch->pos[idx];

         for (uint32_t k = 0; k < ubatch.n_seq_id[s]; ++k) {
             const llama_seq_id seq_id = ubatch->seq_id[s][k];

Although depending on what you need it's also possible to swap the two inner loops:

 for (uint32_t s = 0; s < n_seqs; ++s) {

     for (uint32_t k = 0; k < ubatch.n_seq_id[s]; ++k) {
         const llama_seq_id seq_id = ubatch->seq_id[s][k];

         for (uint32_t j = 0; j < n_seq_tokens; ++j) {
             const uint32_t idx = s*n_seq_tokens + j;

             const llama_pos p1 = ubatch->pos[idx];

In this situation, you would not need to check ubatch.equal_seqs unless unique sequences are required.

ggerganov (Member, Author) replied:
Ok, thank you. I think I understand now.
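
To consolidate the thread above, a sketch (not code from this PR) of a traversal that visits every (token, sequence id) pair and, per the invariants described by compilade, is valid for both split_simple() and split_equal() ubatches:

#include "llama-batch.h" // llama_ubatch

// sketch only - for split_simple(): n_seqs == n_tokens and n_seq_tokens == 1;
// for split_equal(): n_seqs counts aggregated groups of n_seq_tokens tokens each
static void for_each_token_seq(const llama_ubatch & ubatch) {
    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
            const uint32_t idx = s*ubatch.n_seq_tokens + j; // token index within the ubatch

            const llama_pos p1 = ubatch.pos[idx]; // position of this token
            (void) p1;

            // all sequence ids this token is assigned to
            for (int32_t k = 0; k < ubatch.n_seq_id[s]; ++k) {
                const llama_seq_id seq_id = ubatch.seq_id[s][k];

                // ... use (idx, p1, seq_id) here ...
                (void) seq_id;
            }
        }
    }
}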

if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
@@ -881,7 +885,7 @@ int llama_context::encode(llama_batch & inp_batch) {
// TODO: the seuqence indexing here is likely not correct in the general case
// probably works only for split_simple
cross.seq_ids_enc.resize(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
for (uint32_t i = 0; i < n_tokens; i++) {
cross.seq_ids_enc[i].clear();
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
llama_seq_id seq_id = ubatch.seq_id[i][s];
@@ -912,30 +916,30 @@ int llama_context::decode(llama_batch & inp_batch) {
}

// temporary allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
batch_allocr->init(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);

const llama_batch & batch = batch_allocr.batch;
const llama_batch & batch = batch_allocr->get_batch();

const auto & vocab = model.vocab;
const auto & hparams = model.hparams;

const int32_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;

const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
const uint32_t n_tokens_all = batch.n_tokens;

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

// TODO: move the validation to the llama_batch_allocr
if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}

if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
return -1;
}
}
@@ -944,7 +948,7 @@ int llama_context::decode(llama_batch & inp_batch) {
// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

int64_t n_outputs_all = 0;
uint32_t n_outputs_all = 0;

// count outputs
for (uint32_t i = 0; i < n_tokens_all; ++i) {
Expand All @@ -954,7 +958,7 @@ int llama_context::decode(llama_batch & inp_batch) {
if (embd_pooled) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
__func__, n_outputs_all, n_tokens_all);
return -1;
}
@@ -1024,7 +1028,7 @@ int llama_context::decode(llama_batch & inp_batch) {

// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
return -2;
};

@@ -1063,6 +1067,7 @@ int llama_context::decode(llama_batch & inp_batch) {
pos_min[s] = std::numeric_limits<llama_pos>::max();
}

// TODO: fix sequence indexing
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto & seq_id = ubatch.seq_id[i][0];

@@ -1176,14 +1181,14 @@ int llama_context::decode(llama_batch & inp_batch) {
n_outputs = n_outputs_all;

// set output mappings
{
if (n_outputs > 0) {
bool sorted_output = true;

auto & out_ids = mstate->out_ids();

GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);

for (int64_t i = 0; i < n_outputs_all; ++i) {
for (int64_t i = 0; i < n_outputs; ++i) {
int64_t out_id = out_ids[i];
output_ids[out_id] = i;
if (out_id != i) {
@@ -1195,20 +1200,22 @@ int llama_context::decode(llama_batch & inp_batch) {
// note: this is mostly relevant for recurrent models atm
if (!sorted_output) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
const uint64_t n_embd = model.hparams.n_embd;

GGML_ASSERT((size_t) n_outputs == out_ids.size());

// TODO: is there something more efficient which also minimizes swaps?
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
for (int32_t i = 0; i < n_outputs - 1; ++i) {
int32_t j_min = i;
for (int32_t j = i + 1; j < n_outputs; ++j) {
for (uint32_t i = 0; i < n_outputs - 1; ++i) {
uint32_t j_min = i;
for (uint32_t j = i + 1; j < n_outputs; ++j) {
if (out_ids[j] < out_ids[j_min]) {
j_min = j;
}
}
if (j_min == i) { continue; }
if (j_min == i) {
continue;
}
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1221,8 +1228,10 @@ int llama_context::decode(llama_batch & inp_batch) {
}
}
}

std::fill(output_ids.begin(), output_ids.end(), -1);
for (int32_t i = 0; i < n_outputs; ++i) {

for (uint32_t i = 0; i < n_outputs; ++i) {
output_ids[out_ids[i]] = i;
}
}
@@ -1242,7 +1251,7 @@ int llama_context::decode(llama_batch & inp_batch) {
// output
//

int32_t llama_context::output_reserve(int32_t n_outputs) {
uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;

@@ -1308,8 +1317,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
// set all ids as invalid (negative)
std::fill(output_ids.begin(), output_ids.end(), -1);

this->n_outputs = 0;
this->n_outputs_max = n_outputs_max;
this->n_outputs = 0;

return n_outputs_max;
}
@@ -1800,14 +1808,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {

std::vector<int32_t> w_output_pos;

GGML_ASSERT(n_outputs <= n_outputs_max);

w_output_pos.resize(n_outputs);

// build a more compact representation of the output ids
for (size_t i = 0; i < n_batch(); ++i) {
// map an output id to a position in the batch
int32_t pos = output_ids[i];
int64_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
@@ -2082,7 +2088,7 @@ void llama_context::opt_epoch_iter(

embd_seq.clear();

int64_t n_outputs_all = n_tokens_all;
uint32_t n_outputs_all = n_tokens_all;

auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
@@ -2092,7 +2098,7 @@ void llama_context::opt_epoch_iter(

// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
GGML_ABORT("TODO: handle this error");
};
