
llama : add llama_batch_ext #11875


Open · wants to merge 61 commits into base: master

Changes shown below are from 2 of the 61 commits.

Commits
4ed4fe7
first proposal for private llama_batch
ngxson Feb 13, 2025
f2e59a8
rework, targeting llama-server
ngxson Feb 14, 2025
17d3658
move to llama_batch_ext
ngxson Feb 15, 2025
85ef80c
server : use llama_batch_ext
ngxson Feb 15, 2025
aed4a8e
fix server
ngxson Feb 16, 2025
4bf7ca3
llama_decode_ext
ngxson Feb 24, 2025
a1b1dea
Merge branch 'master' into xsn/private_batch_api
ngxson Feb 24, 2025
f0ffd81
adapt common
ngxson Mar 1, 2025
9e75c49
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 1, 2025
40989f4
correct llama_decode_ext
ngxson Mar 1, 2025
1170135
llama_batch_ext_add_text
ngxson Mar 1, 2025
1d6ba97
remove token_info API
ngxson Mar 1, 2025
46596ca
apply various in places
ngxson Mar 1, 2025
17f954c
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 13, 2025
86973cb
fix merge errors
ngxson Mar 13, 2025
4aabf4e
return output ID from llama_batch_ext_add/set
ngxson Mar 13, 2025
47086fa
apply to the rest
ngxson Mar 13, 2025
9fb2d81
fix common_batch missing seq_id
ngxson Mar 13, 2025
65f0184
compile ok
ngxson Mar 13, 2025
c3dd790
fix llama_batch_ext_init_from_text
ngxson Mar 13, 2025
04f8641
rm redundant llama_batch_ext_set_output_last
ngxson Mar 13, 2025
54566ad
correct comment
ngxson Mar 13, 2025
bfdddbc
bring back mistakenly deleted llama_batch_init/free
ngxson Mar 13, 2025
5e6a6d4
fix llama-run n_past
ngxson Mar 14, 2025
3294036
fix gemma3-cli
ngxson Mar 14, 2025
07d84fa
fix missing n_past in various places
ngxson Mar 14, 2025
ba79369
fix llama_batch_ext_init_from_embd
ngxson Mar 14, 2025
a363251
qwen2vl: use llama_batch_ext_set_pos
ngxson Mar 14, 2025
8e7714f
fix compile
ngxson Mar 14, 2025
eaffba0
llama_batch_ext_ptr::from_text/embd
ngxson Mar 14, 2025
116b9a1
rename to init_from_text
ngxson Mar 14, 2025
624a683
fix compile
ngxson Mar 14, 2025
de788e0
Update examples/tts/tts.cpp
ngxson Mar 17, 2025
eab5606
Apply suggestions from code review
ngxson Mar 17, 2025
dc4bb64
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 18, 2025
7a3c178
speculative : adapt to new llama API
ggerganov Mar 18, 2025
23d7407
Merge pull request #15 from ggml-org/xsn/private_batch_api
ngxson Mar 19, 2025
b0db7fc
android : adapt to new API
ggerganov Mar 19, 2025
96ca6e8
swift : adapt to new API
ggerganov Mar 19, 2025
32c2c41
android : fix permission
ngxson Mar 19, 2025
6f54ee6
retrieval : avoid common_batch
ggerganov Mar 19, 2025
8b80d68
embedding : avoid common_batch
ggerganov Mar 19, 2025
76fd7d6
perplexity : avoid common_batch
ggerganov Mar 20, 2025
8a23b4a
server : avoid common_batch
ggerganov Mar 20, 2025
b8b1732
server : remove old commented code [no ci]
ggerganov Mar 20, 2025
bd51d63
Merge pull request #16 from ggml-org/xsn/private_batch_api_pooling_none
ngxson Mar 20, 2025
30f1db9
remove C API llama_batch_ext_init_from_text
ngxson Mar 20, 2025
c5a0176
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 21, 2025
2134cab
add cpp batch.add_text wrapper
ngxson Mar 21, 2025
2cec1cf
move various places to batch.add_text
ngxson Mar 21, 2025
3802ff2
add batch.clear() and batch.n_tokens()
ngxson Mar 21, 2025
e8827a6
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 23, 2025
a9efdbb
qwen2vl: fix mrope position
ngxson Mar 23, 2025
1434c2c
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 25, 2025
d18a79e
llama_batch_ext_init with ctx
ngxson Mar 25, 2025
c4fea7f
fix qwzn2vl mrope position input
ngxson Mar 25, 2025
42062cc
fix build
ngxson Mar 25, 2025
56e82d0
fix server
ngxson Mar 25, 2025
50fb396
server: fix batch_spec
ngxson Mar 25, 2025
8ec0ff9
fix embeddings and retrieval
ngxson Mar 27, 2025
c1f4a78
correct output_id for llama-cpp header
ngxson Mar 27, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -98,6 +98,7 @@ examples/server/*.css.hpp
examples/server/*.html.hpp
examples/server/*.js.hpp
examples/server/*.mjs.hpp
examples/server/*.gz.hpp
!build_64.sh
!examples/*.bat
!examples/*/*.kts
21 changes: 8 additions & 13 deletions common/common.cpp
@@ -580,6 +580,7 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam
return buf.str();
}

/*
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
std::stringstream buf;

@@ -614,6 +615,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat

return buf.str();
}
*/

void string_process_escapes(std::string & input) {
std::size_t input_len = input.length();
@@ -1608,27 +1610,20 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
// Batch utils
//

void common_batch_clear(struct llama_batch & batch) {
batch.n_tokens = 0;
void common_batch_clear(struct llama_batch * batch) {
llama_batch_clear(batch);
}

void common_batch_add(
struct llama_batch & batch,
struct llama_batch * batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");

batch.token [batch.n_tokens] = id;
batch.pos [batch.n_tokens] = pos;
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits);
if (res == -1) {
LOG_ERR("%s: llama_batch size exceeded\n", __func__);
}
batch.logits [batch.n_tokens] = logits;

batch.n_tokens++;
}

//
4 changes: 2 additions & 2 deletions common/common.h
@@ -554,10 +554,10 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
// Batch utils
//

void common_batch_clear(struct llama_batch & batch);
void common_batch_clear(struct llama_batch * batch);

void common_batch_add(
struct llama_batch & batch,
struct llama_batch * batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
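For orientation, here is a minimal sketch of how a caller drives the updated helpers: the batch is now created and passed around as a `llama_batch *` instead of a value. The signatures are taken from the diff at this revision; later commits in this PR rename the underlying C API to `llama_batch_ext_*`, and the helper name `decode_prompt` is made up for the example, so treat this as illustrative rather than final.

```cpp
// Sketch only: prompt decoding through the pointer-based common_batch_* helpers above.
// Assumes the two-argument llama_batch_init(n_tokens_max, n_seq_max) shown in this revision.
#include "common.h"
#include "llama.h"

#include <vector>

static void decode_prompt(llama_context * ctx, const std::vector<llama_token> & prompt) {
    llama_batch * batch = llama_batch_init(llama_n_batch(ctx), /*n_seq_max=*/ 1);

    common_batch_clear(batch);
    for (size_t i = 0; i < prompt.size(); ++i) {
        const bool need_logits = (i == prompt.size() - 1); // logits only for the last token
        common_batch_add(batch, prompt[i], /*pos=*/ (llama_pos) i, /*seq_ids=*/ { 0 }, need_logits);
    }

    llama_decode(ctx, batch); // llama_decode accepts the batch pointer at this revision

    llama_batch_free(batch);
}
```

Note that overflow handling moves with the change: instead of the old `GGML_ASSERT` on the struct's capacity, `common_batch_add` now logs an error when `llama_batch_add_text_token` returns -1.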
6 changes: 3 additions & 3 deletions common/speculative.cpp
@@ -13,7 +13,7 @@ struct common_speculative {
struct llama_context * ctx;
struct common_sampler * smpl;

llama_batch batch;
llama_batch * batch;
llama_tokens prompt;
};

@@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 1),
/* .prompt = */ {},
};

@@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft(
}

// we should rarely end-up here during normal decoding
if (batch.n_tokens > 0) {
if (llama_batch_get_n_tokens(batch) > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

llama_decode(ctx, batch);
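The same pattern in isolation: the draft batch lives behind a pointer, its fill level is read through `llama_batch_get_n_tokens()` rather than the old `n_tokens` field, and it is released with `llama_batch_free()`. A hedged sketch under those assumptions; the `draft_state` struct and its helpers are invented for the example, and whether to hold a raw pointer here or the `llama_batch_ptr` wrapper from `include/llama-cpp.h` below is a style choice.

```cpp
// Sketch only: pointer-held batch with accessor-based size checks, mirroring common_speculative.
#include "llama.h"

struct draft_state {
    llama_context * ctx   = nullptr;
    llama_batch   * batch = nullptr;
};

static draft_state draft_state_init(llama_context * ctx_dft) {
    // n_seq_max = 1: the draft context only ever uses a single sequence
    return { ctx_dft, llama_batch_init(llama_n_batch(ctx_dft), /*n_seq_max=*/ 1) };
}

static void draft_state_flush(draft_state & st) {
    // decode only if something was actually queued
    if (llama_batch_get_n_tokens(st.batch) > 0) {
        llama_decode(st.ctx, st.batch);
        llama_batch_clear(st.batch);
    }
}

static void draft_state_free(draft_state & st) {
    llama_batch_free(st.batch); // pairs with llama_batch_init above
    st.batch = nullptr;
}
```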
98 changes: 43 additions & 55 deletions examples/server/server.cpp
@@ -1215,7 +1215,7 @@ struct server_slot {
// only used for completion/embedding/infill/rerank
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;

llama_batch batch_spec = {};
llama_batch_ptr batch_spec;

llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
@@ -1787,7 +1787,7 @@ struct server_context {

llama_context_params cparams_dft;

llama_batch batch = {};
llama_batch_ptr batch;

bool clean_kv_cache = true;
bool add_bos_token = true;
@@ -1820,11 +1820,7 @@ struct server_context {

common_speculative_free(slot.spec);
slot.spec = nullptr;

llama_batch_free(slot.batch_spec);
}

llama_batch_free(batch);
}

bool load_model(const common_params & params) {
Expand Down Expand Up @@ -1944,7 +1940,7 @@ struct server_context {
slot.n_predict = params_base.n_predict;

if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1));

slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
@@ -1969,7 +1965,7 @@

slot.reset();

slots.push_back(slot);
slots.push_back(std::move(slot));
}

default_generation_settings_for_props = slots[0].to_json();
@@ -1980,7 +1976,7 @@
const int32_t n_batch = llama_n_batch(ctx);

// only a single seq_id per token is needed
Review thread on the line above:

Member: This comment is obsolete.

Collaborator Author: yeah, I think I removed it in one of the commits above; we don't need n_batch anymore, so I removed this whole code block

batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1));
}

metrics.init();
@@ -2098,9 +2094,7 @@ struct server_context {
}

if (slot.ctx_dft) {
llama_batch_free(slot.batch_spec);

slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1));
}

slot.state = SLOT_STATE_STARTED;
@@ -2408,7 +2402,7 @@ struct server_context {
queue_results.send(std::move(res));
}

void send_embedding(const server_slot & slot, const llama_batch & batch) {
void send_embedding(const server_slot & slot, llama_batch_ptr & batch) {
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.id_task;
res->index = slot.index;
@@ -2419,18 +2413,19 @@

std::vector<float> embd_res(n_embd, 0.0f);

for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
if (!tok.logits || tok.seq_id[0] != slot.id) {
continue;
}

const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}

if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);

res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
@@ -2451,24 +2446,25 @@
queue_results.send(std::move(res));
}

void send_rerank(const server_slot & slot, const llama_batch & batch) {
void send_rerank(const server_slot & slot, llama_batch_ptr & batch) {
auto res = std::make_unique<server_task_result_rerank>();
res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;

for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
if (!tok.logits || tok.seq_id[0] != slot.id) {
continue;
}

const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}

if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);

res->score = -1e6;
continue;
@@ -2859,7 +2855,7 @@ struct server_context {
}

// start populating the batch for this iteration
common_batch_clear(batch);
common_batch_clear(batch.get());

// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;
@@ -2881,9 +2877,9 @@
continue;
}

slot.i_batch = batch.n_tokens;
slot.i_batch = llama_batch_get_n_tokens(batch.get());

common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true);

slot.n_past += 1;

@@ -2900,7 +2896,7 @@
int32_t n_ubatch = llama_n_ubatch(ctx);

// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
@@ -3066,7 +3062,7 @@ struct server_context {
// non-causal tasks require to fit the entire prompt in the physical batch
if (slot.is_non_causal()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
continue;
}
}
@@ -3086,11 +3082,11 @@
slot.cache_tokens.resize(slot.n_past);

// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) {
// without pooling, we want to output the embeddings for all the tokens in the batch
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;

common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);

if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3100,13 +3096,13 @@
slot.n_past++;
}

SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);

// entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens) {
slot.state = SLOT_STATE_DONE_PROMPT;

GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0);

common_sampler_reset(slot.smpl);

@@ -3116,27 +3112,27 @@
}

// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
llama_batch_set_logits_last(batch.get());

slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
slot.i_batch = llama_batch_get_n_tokens(batch.get()) - 1;

SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get()));
}
}

if (batch.n_tokens >= n_batch) {
if (llama_batch_get_n_tokens(batch.get()) >= n_batch) {
break;
}
}
}

if (batch.n_tokens == 0) {
if (llama_batch_get_n_tokens(batch.get()) == 0) {
SRV_WRN("%s", "no tokens to decode\n");
return;
}

SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get()));

if (slot_batched) {
// make sure we're in the right embedding mode
@@ -3146,20 +3142,12 @@
}

// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) {
const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i);

llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens));

const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode(ctx, batch_view.get());
metrics.on_decoded(slots);

if (ret != 0) {
@@ -3294,16 +3282,16 @@
}

// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
common_batch_clear(slot.batch_spec.get());
common_batch_add (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true);

for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true);
}

SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get()));

llama_decode(ctx, slot.batch_spec);
llama_decode(ctx, slot.batch_spec.get());

// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
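The key change in the decode loop above: instead of assembling a temporary `llama_batch` view struct by hand from the batch's internal arrays, the code asks for a view with `llama_batch_get_view()` and decodes that. A condensed sketch of the pattern, using the function names as they appear at this revision; the view is wrapped in `llama_batch_ptr` so it is released automatically, as server.cpp does, and the helper name `decode_in_chunks` is invented for the example.

```cpp
// Sketch only: decoding a large batch in n_batch-sized chunks via batch views.
#include "llama-cpp.h" // llama_batch_ptr, see the header change below

#include <algorithm>

static int decode_in_chunks(llama_context * ctx, llama_batch * batch, int32_t n_batch) {
    const int32_t n_total = llama_batch_get_n_tokens(batch);

    for (int32_t i = 0; i < n_total; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, n_total - i);

        // view over tokens [i, i + n_tokens); the smart pointer frees the view object
        llama_batch_ptr view(llama_batch_get_view(batch, i, n_tokens));

        const int ret = llama_decode(ctx, view.get());
        if (ret != 0) {
            return ret; // the server retries with a smaller n_batch; a sketch can just bail out
        }
    }

    return 0;
}
```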
5 changes: 5 additions & 0 deletions include/llama-cpp.h
@@ -24,7 +24,12 @@ struct llama_adapter_lora_deleter {
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
};

struct llama_batch_deleter {
void operator()(llama_batch * batch) { llama_batch_free(batch); }
};

typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
typedef std::unique_ptr<llama_batch, llama_batch_deleter> llama_batch_ptr;
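With the new deleter in place, a batch can be owned RAII-style, which is what the server code above switches to. A minimal sketch, assuming the two-argument `llama_batch_init()` and the `llama_batch_add_text_token()` signature used in `common/common.cpp` at this revision; the function name `decode_one_token` is invented for the example.

```cpp
// Sketch only: RAII ownership of a batch via llama_batch_ptr.
#include "llama-cpp.h"

static void decode_one_token(llama_context * ctx, llama_token tok) {
    llama_batch_ptr batch(llama_batch_init(/*n_tokens_max=*/ 1, /*n_seq_max=*/ 1));

    const llama_seq_id seq_id = 0;
    llama_batch_add_text_token(batch.get(), tok, /*pos=*/ 0, &seq_id, /*n_seq_ids=*/ 1, /*logits=*/ true);

    llama_decode(ctx, batch.get());
    // llama_batch_free() runs automatically when `batch` goes out of scope
}
```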