
llama : add llama_batch_ext #11875

Status: Open. Wants to merge 61 commits into master from xsn/private_batch_api.
The file changes shown below are from 1 commit.

Commits (61)
4ed4fe7
first proposal for private llama_batch
ngxson Feb 13, 2025
f2e59a8
rework, targeting llama-server
ngxson Feb 14, 2025
17d3658
move to llama_batch_ext
ngxson Feb 15, 2025
85ef80c
server : use llama_batch_ext
ngxson Feb 15, 2025
aed4a8e
fix server
ngxson Feb 16, 2025
4bf7ca3
llama_decode_ext
ngxson Feb 24, 2025
a1b1dea
Merge branch 'master' into xsn/private_batch_api
ngxson Feb 24, 2025
f0ffd81
adapt common
ngxson Mar 1, 2025
9e75c49
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 1, 2025
40989f4
correct llama_decode_ext
ngxson Mar 1, 2025
1170135
llama_batch_ext_add_text
ngxson Mar 1, 2025
1d6ba97
remove token_info API
ngxson Mar 1, 2025
46596ca
apply various in places
ngxson Mar 1, 2025
17f954c
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 13, 2025
86973cb
fix merge errors
ngxson Mar 13, 2025
4aabf4e
return output ID from llama_batch_ext_add/set
ngxson Mar 13, 2025
47086fa
apply to the rest
ngxson Mar 13, 2025
9fb2d81
fix common_batch missing seq_id
ngxson Mar 13, 2025
65f0184
compile ok
ngxson Mar 13, 2025
c3dd790
fix llama_batch_ext_init_from_text
ngxson Mar 13, 2025
04f8641
rm redundant llama_batch_ext_set_output_last
ngxson Mar 13, 2025
54566ad
correct comment
ngxson Mar 13, 2025
bfdddbc
bring back mistakenly deleted llama_batch_init/free
ngxson Mar 13, 2025
5e6a6d4
fix llama-run n_past
ngxson Mar 14, 2025
3294036
fix gemma3-cli
ngxson Mar 14, 2025
07d84fa
fix missing n_past in various places
ngxson Mar 14, 2025
ba79369
fix llama_batch_ext_init_from_embd
ngxson Mar 14, 2025
a363251
qwen2vl: use llama_batch_ext_set_pos
ngxson Mar 14, 2025
8e7714f
fix compile
ngxson Mar 14, 2025
eaffba0
llama_batch_ext_ptr::from_text/embd
ngxson Mar 14, 2025
116b9a1
rename to init_from_text
ngxson Mar 14, 2025
624a683
fix compile
ngxson Mar 14, 2025
de788e0
Update examples/tts/tts.cpp
ngxson Mar 17, 2025
eab5606
Apply suggestions from code review
ngxson Mar 17, 2025
dc4bb64
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 18, 2025
7a3c178
speculative : adapt to new llama API
ggerganov Mar 18, 2025
23d7407
Merge pull request #15 from ggml-org/xsn/private_batch_api
ngxson Mar 19, 2025
b0db7fc
android : adapt to new API
ggerganov Mar 19, 2025
96ca6e8
swift : adapt to new API
ggerganov Mar 19, 2025
32c2c41
android : fix permission
ngxson Mar 19, 2025
6f54ee6
retrieval : avoid common_batch
ggerganov Mar 19, 2025
8b80d68
embedding : avoid common_batch
ggerganov Mar 19, 2025
76fd7d6
perplexity : avoid common_batch
ggerganov Mar 20, 2025
8a23b4a
server : avoid common_batch
ggerganov Mar 20, 2025
b8b1732
server : remove old commented code [no ci]
ggerganov Mar 20, 2025
bd51d63
Merge pull request #16 from ggml-org/xsn/private_batch_api_pooling_none
ngxson Mar 20, 2025
30f1db9
remove C API llama_batch_ext_init_from_text
ngxson Mar 20, 2025
c5a0176
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 21, 2025
2134cab
add cpp batch.add_text wrapper
ngxson Mar 21, 2025
2cec1cf
move various places to batch.add_text
ngxson Mar 21, 2025
3802ff2
add batch.clear() and batch.n_tokens()
ngxson Mar 21, 2025
e8827a6
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 23, 2025
a9efdbb
qwen2vl: fix mrope position
ngxson Mar 23, 2025
1434c2c
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 25, 2025
d18a79e
llama_batch_ext_init with ctx
ngxson Mar 25, 2025
c4fea7f
fix qwzn2vl mrope position input
ngxson Mar 25, 2025
42062cc
fix build
ngxson Mar 25, 2025
56e82d0
fix server
ngxson Mar 25, 2025
50fb396
server: fix batch_spec
ngxson Mar 25, 2025
8ec0ff9
fix embeddings and retrieval
ngxson Mar 27, 2025
c1f4a78
correct output_id for llama-cpp header
ngxson Mar 27, 2025
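
For orientation, here is the shape of the API that the commits above converge on and that the file diffs below apply example by example: a batch is created from a llama_context instead of explicit (n_tokens, n_seq_max) sizes, filled token by token with add_text (or built in one call with init_from_text), and submitted through llama_decode_ext. This is a minimal sketch, not code from the PR: it assumes the llama_batch_ext_ptr C++ wrapper (presumably from llama-cpp.h) and the signatures visible in the hunks below, and the helper names decode_prompt / decode_prompt_one_shot are invented for the example.

#include "llama.h"
#include "llama-cpp.h" // assumed home of the llama_batch_ext_ptr RAII wrapper

#include <vector>

// Build a batch token by token, request logits only for the last token, then decode.
static bool decode_prompt(llama_context * ctx, const std::vector<llama_token> & prompt) {
    llama_batch_ext_ptr batch(ctx); // batch capacity now comes from the context

    for (size_t i = 0; i < prompt.size(); ++i) {
        batch.add_text(prompt[i], (llama_pos) i, /* seq_id */ 0, /* output */ false);
    }
    llama_batch_ext_set_output_last(batch.get());

    return llama_decode_ext(ctx, batch.get()) == 0;
}

// One-shot variant used throughout the examples: tokens -> batch -> decode.
static bool decode_prompt_one_shot(llama_context * ctx, std::vector<llama_token> & prompt) {
    auto batch = llama_batch_ext_ptr::init_from_text(ctx, prompt.data(), prompt.size(),
                                                     /* pos0 */ 0, /* seq_id */ 0, /* output_last */ true);
    return llama_decode_ext(ctx, batch.get()) == 0;
}
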
4 changes: 2 additions & 2 deletions common/common.cpp
@@ -1016,7 +1016,7 @@ struct common_init_result common_init_from_params(common_params & params) {
}

if (llama_model_has_encoder(model)) {
- auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), tmp.size(), 0, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), tmp.size(), 0, 0, true);
llama_encode_ext(lctx, batch.get());
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
@@ -1026,7 +1026,7 @@ struct common_init_result common_init_from_params(common_params & params) {
tmp.push_back(decoder_start_token_id);
}
if (llama_model_has_decoder(model)) {
- auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true);
llama_decode_ext(lctx, batch.get());
}
llama_kv_self_clear(lctx);

2 changes: 1 addition & 1 deletion common/speculative.cpp
@@ -23,7 +23,7 @@ struct common_speculative * common_speculative_init(
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
- /* .batch = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
+ /* .batch = */ llama_batch_ext_ptr(ctx_dft),
/* .prompt = */ {},
};

2 changes: 1 addition & 1 deletion examples/batched-bench/batched-bench.cpp
@@ -59,7 +59,7 @@ int main(int argc, char ** argv) {

const int32_t n_kv_max = llama_n_ctx(ctx);

- llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1);
+ llama_batch_ext * batch = llama_batch_ext_init(ctx);

// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {

2 changes: 1 addition & 1 deletion examples/batched/batched.cpp
@@ -102,7 +102,7 @@ int main(int argc, char ** argv) {

// create a llama_batch
// we use this object to submit token data for decoding
- llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel);
+ llama_batch_ext * batch = llama_batch_ext_init(ctx);

std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) {

2 changes: 1 addition & 1 deletion examples/cvector-generator/cvector-generator.cpp
@@ -343,7 +343,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {

static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_self_clear(ctx);
- auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;

2 changes: 1 addition & 1 deletion examples/embedding/embedding.cpp
@@ -167,7 +167,7 @@ int main(int argc, char ** argv) {

// initialize batch
const int n_prompts = prompts.size();
- llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
+ llama_batch_ext * batch = llama_batch_ext_init(ctx);

// count number of embeddings
int n_embd_count = 0;

2 changes: 1 addition & 1 deletion examples/eval-callback/eval-callback.cpp
@@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) {

std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);

- auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;

4 changes: 2 additions & 2 deletions examples/gritlm/gritlm.cpp
@@ -14,7 +14,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);

- llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1));
+ llama_batch_ext_ptr batch(ctx);

for (uint64_t i = 0; i < sentences.size(); i++) {
batch.clear();
@@ -105,7 +105,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);

- llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1));
+ llama_batch_ext_ptr batch(ctx);

std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
int32_t i_current_token = 0;

2 changes: 1 addition & 1 deletion examples/imatrix/imatrix.cpp
@@ -497,7 +497,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
// clear the KV cache
llama_kv_self_clear(ctx);

- llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
+ llama_batch_ext * batch = llama_batch_ext_init(ctx);

for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;

2 changes: 1 addition & 1 deletion examples/infill/infill.cpp
@@ -353,7 +353,7 @@ int main(int argc, char ** argv) {

LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

- auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(ctx, &embd[i], n_eval, n_past, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;

4 changes: 2 additions & 2 deletions examples/llama-bench/llama-bench.cpp
@@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab;
}
- auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), n_tokens, n_past + n_processed, 0, true);
llama_decode_ext(ctx, batch.get());
n_processed += n_tokens;
}
@@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;

for (int i = 0; i < n_gen; i++) {
- auto batch = llama_batch_ext_ptr::init_from_text(&token, 1, n_past + i, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(ctx, &token, 1, n_past + i, 0, true);
llama_decode_ext(ctx, batch.get());
llama_synchronize(ctx);
token = std::rand() % n_vocab;

5 changes: 3 additions & 2 deletions examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -273,8 +273,9 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(

extern "C"
JNIEXPORT jlong JNICALL
- Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
- llama_batch_ext * batch = llama_batch_ext_init(n_tokens, n_seq_max);
+ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jlong context_pointer) {
+ const auto context = reinterpret_cast<llama_context *>(context_pointer);
+ llama_batch_ext * batch = llama_batch_ext_init(context);

return reinterpret_cast<jlong>(batch);
}

examples/llama.android (Kotlin bindings, class LLamaAndroid)
@@ -45,7 +45,7 @@ class LLamaAndroid {
private external fun free_context(context: Long)
private external fun backend_init(numa: Boolean)
private external fun backend_free()
- private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
+ private external fun new_batch(context: Long): Long
private external fun free_batch(batch: Long)
private external fun new_sampler(): Long
private external fun free_sampler(sampler: Long)
@@ -102,7 +102,7 @@ class LLamaAndroid {
val context = new_context(model)
if (context == 0L) throw IllegalStateException("new_context() failed")

- val batch = new_batch(512, 0, 1)
+ val batch = new_batch(context)
if (batch == 0L) throw IllegalStateException("new_batch() failed")

val sampler = new_sampler()

2 changes: 1 addition & 1 deletion examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -26,7 +26,7 @@ actor LlamaContext {
self.model = model
self.context = context
self.tokens_list = []
- self.batch = llama_batch_ext_init(512, 1)
+ self.batch = llama_batch_ext_init(context)
self.temporary_invalid_cchars = []
let sparams = llama_sampler_chain_default_params()
self.sampling = llama_sampler_chain_init(sparams)

5 changes: 3 additions & 2 deletions examples/llava/gemma3-cli.cpp
@@ -74,7 +74,7 @@ struct gemma3_context {
lctx = llama_init.context.get();
vocab = llama_model_get_vocab(model);
n_threads = params.cpuparams.n_threads;
- batch.reset(llama_batch_ext_init(params.n_batch, 1));
+ batch.reset(llama_batch_ext_init(lctx));
init_clip_model(params);
}

@@ -147,7 +147,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
int64_t t1 = ggml_time_ms();
eval_text(ctx, "<start_of_image>");
llama_set_causal_attn(ctx.lctx, false);
- llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0));
+ llama_batch_ext_ptr batch_img = llama_batch_ext_ptr::init_from_embd(
+ ctx.lctx, image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0);
if (llama_decode_ext(ctx.lctx, batch_img.get())) {
LOG_ERR("failed to decode image\n");
return 1;

2 changes: 1 addition & 1 deletion examples/llava/llava-cli.cpp
@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
- auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(ctx_llama, &tokens[i], n_eval, *n_past, 0, true);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;

2 changes: 1 addition & 1 deletion examples/llava/llava.cpp
@@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
n_eval = n_batch;
}
float * embd = image_embed->embed+i*n_embd;
- auto batch = llama_batch_ext_ptr::init_from_embd(embd, n_eval, n_embd, 0, 0);
+ auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, embd, n_eval, n_embd, 0, 0);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;

2 changes: 1 addition & 1 deletion examples/llava/minicpmv-cli.cpp
@@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
- auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(ctx_llama, &tokens[i], n_eval, *n_past, 0, true);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;

9 changes: 4 additions & 5 deletions examples/llava/qwen2vl-cli.cpp
@@ -67,8 +67,8 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));

float * batch_embd = image_embed->embed+i*n_embd;
- auto batch = llama_batch_ext_ptr::init_from_embd(batch_embd, n_eval, n_embd, 0, 0);
- llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval * 4);
+ const llama_pos * pos = batch_mrope_pos.data();
+ auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, batch_embd, n_eval, n_embd, pos, 0);

if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
@@ -97,12 +97,11 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
pos[j] = *st_pos_id + (j % n_eval);
}

- llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1));
+ llama_batch_ext_ptr batch(ctx_llama);
for (int j = 0; j < n_eval; j++) {
llama_token token = tokens[i + j];
- batch.add_text(token, 0, 0, false); // position is set in the next step
+ batch.add_text(token, *st_pos_id + i + j, 0, false);
}
- llama_batch_ext_set_pos(batch.get(), pos.data(), pos.size());
llama_batch_ext_set_output_last(batch.get());

if (llama_decode_ext(ctx_llama, batch.get())) {

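A note on the qwen2vl hunk above: the M-RoPE positions (4 values per token) are no longer patched in afterwards with llama_batch_ext_set_pos; they are handed to the batch when it is created. Below is a hedged sketch of that pattern, assuming the init_from_embd signature used in the hunk (ctx, embd, n_tokens, n_embd, pos, seq_id); the helper name is invented, and the position buffer is assumed to be laid out the way batch_mrope_pos is built above.

#include "llama.h"
#include "llama-cpp.h"

#include <vector>

// Illustration only: decode an image embedding whose 4-per-token M-RoPE positions
// are supplied at batch construction time rather than via a separate set_pos call.
static bool decode_image_embd(llama_context * ctx, float * embd, int n_tokens, int n_embd,
                              const std::vector<llama_pos> & mrope_pos /* 4 * n_tokens values */) {
    auto batch = llama_batch_ext_ptr::init_from_embd(ctx, embd, n_tokens, n_embd,
                                                     mrope_pos.data(), /* seq_id */ 0);
    return llama_decode_ext(ctx, batch.get()) == 0;
}
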
6 changes: 3 additions & 3 deletions examples/lookahead/lookahead.cpp
@@ -92,8 +92,8 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us();

// eval the prompt
- auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true);
- auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, n_input - 1, 0, true);
+ auto batch0 = llama_batch_ext_ptr::init_from_text(ctx, inp.data(), n_input - 1, 0, 0, true);
+ auto batch1 = llama_batch_ext_ptr::init_from_text(ctx, &inp.back(), 1, n_input - 1, 0, true);
llama_decode_ext(ctx, batch0.get());
llama_decode_ext(ctx, batch1.get());

@@ -117,7 +117,7 @@ int main(int argc, char ** argv) {
// seq_id == 0 : the current input token
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
// seq_id [W + 1, W + G] : verification n-grams
- llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);
+ llama_batch_ext * batch = llama_batch_ext_init(ctx);

// target model sampling context
struct common_sampler * smpl = common_sampler_init(model, params.sampling);

6 changes: 3 additions & 3 deletions examples/lookup/lookup.cpp
@@ -92,8 +92,8 @@ int main(int argc, char ** argv){

const auto t_enc_start = ggml_time_us();

- auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true);
- auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, n_input - 1, 0, true);
+ auto batch0 = llama_batch_ext_ptr::init_from_text(ctx, inp.data(), n_input - 1, 0, 0, true);
+ auto batch1 = llama_batch_ext_ptr::init_from_text(ctx, &inp.back(), 1, n_input - 1, 0, true);
llama_decode_ext(ctx, batch0.get());
llama_decode_ext(ctx, batch1.get());

@@ -111,7 +111,7 @@ int main(int argc, char ** argv){

std::vector<llama_token> draft;

- llama_batch_ext_ptr batch_tgt(llama_batch_ext_init(params.n_ctx, 1));
+ llama_batch_ext_ptr batch_tgt(ctx);

// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);

4 changes: 2 additions & 2 deletions examples/main/main.cpp
@@ -548,7 +548,7 @@ int main(int argc, char ** argv) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();

- auto batch = llama_batch_ext_ptr::init_from_text(enc_input_buf, enc_input_size, 0, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(ctx, enc_input_buf, enc_input_size, 0, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
@@ -669,7 +669,7 @@ int main(int argc, char ** argv) {

LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

- auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
+ auto batch = llama_batch_ext_ptr::init_from_text(ctx, &embd[i], n_eval, n_past, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;

2 changes: 1 addition & 1 deletion examples/parallel/parallel.cpp
@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {

// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
- llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 1));
+ llama_batch_ext_ptr batch(ctx);

int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;

2 changes: 1 addition & 1 deletion examples/passkey/passkey.cpp
@@ -123,7 +123,7 @@ int main(int argc, char ** argv) {
LOG_INF("prompt tokens: %d\n", n_tokens_all);
//LOG_INF("prompt: %s\n", params.prompt.c_str());

- llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));
+ llama_batch_ext_ptr batch(ctx);

int n_past = 0;

12 changes: 6 additions & 6 deletions examples/perplexity/perplexity.cpp
@@ -363,7 +363,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
// clear the KV cache
llama_kv_self_clear(ctx);

- llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
+ llama_batch_ext_ptr batch(ctx);

for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -501,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
GGML_ASSERT(params.n_ctx == n_seq * n_ctx);

- llama_batch_ext_ptr batch(llama_batch_ext_init(std::min(n_batch, n_ctx*n_seq), 1));
+ llama_batch_ext_ptr batch(ctx);

std::vector<float> logits;
if (num_batches > 1) {
@@ -830,7 +830,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

- llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 4));
+ llama_batch_ext_ptr batch(ctx);

std::vector<float> tok_logits(n_vocab);
// TODO: this could be made smaller; it's currently the worst-case size
@@ -1112,7 +1112,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

- llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 2));
+ llama_batch_ext_ptr batch(ctx);

std::vector<float> tok_logits(n_vocab);
// TODO: this could be made smaller; it's currently the worst-case size
@@ -1465,7 +1465,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

- llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, max_seq));
+ llama_batch_ext_ptr batch(ctx);

std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@@ -1730,7 +1730,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
// clear the KV cache
llama_kv_self_clear(ctx);

- llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
+ llama_batch_ext_ptr batch(ctx);

for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;

4 changes: 2 additions & 2 deletions examples/retrieval/retrieval.cpp
@@ -213,7 +213,7 @@ int main(int argc, char ** argv) {

// initialize batch
const int n_chunks = chunks.size();
- llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
+ llama_batch_ext * batch = llama_batch_ext_init(ctx);

// allocate output
const int n_embd = llama_model_n_embd(model);
@@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
chunks[i].tokens.clear();
}

- llama_batch_ext * query_batch = llama_batch_ext_init(n_batch, 1);
+ llama_batch_ext * query_batch = llama_batch_ext_init(ctx);

// start loop, receive query and return top k similar chunks based on cosine similarity
std::string query;
