
Commit 49d865f

mamba : adapt perplexity, batched, and batched-bench examples
* perplexity : limit the max number of sequences
  This adapts to what the loaded model can provide.

* llama : add llama_n_max_seq to get the upper limit for seq_ids
  Used by the perplexity example.

* batched : pass n_parallel to the model's context params
  This should have been there already, but it wasn't.

* batched-bench : reserve sequences to support Mamba

* batched-bench : fix tokens being put in wrong sequences
  Generation quality isn't what's measured in there anyway, but at least using the correct sequences avoids using non-consecutive token positions.
1 parent 5403625 commit 49d865f

File tree

    examples/batched-bench/batched-bench.cpp
    examples/batched/batched.cpp
    examples/perplexity/perplexity.cpp
    llama.cpp
    llama.h

5 files changed: +21 -9 lines changed

examples/batched-bench/batched-bench.cpp

Lines changed: 8 additions & 5 deletions
@@ -105,6 +105,9 @@ int main(int argc, char ** argv) {
     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
+    // ensure enough sequences are available
+    ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end());
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
@@ -174,10 +177,10 @@ int main(int argc, char ** argv) {
 
         llama_batch_clear(batch);
 
-        const int n_tokens = is_pp_shared ? pp : pl*pp;
-
-        for (int i = 0; i < n_tokens; ++i) {
-            llama_batch_add(batch, 0, i, { 0 }, false);
+        for (int i = 0; i < pp; ++i) {
+            for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+                llama_batch_add(batch, 0, i, { j }, false);
+            }
         }
         batch.logits[batch.n_tokens - 1] = true;
 
@@ -192,7 +195,7 @@ int main(int argc, char ** argv) {
 
        if (is_pp_shared) {
            for (int32_t i = 1; i < pl; ++i) {
-               llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
+               llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
            }
        }
examples/batched/batched.cpp

Lines changed: 2 additions & 1 deletion
@@ -80,6 +80,7 @@ int main(int argc, char ** argv) {
     ctx_params.seed    = 1234;
     ctx_params.n_ctx   = n_kv_req;
     ctx_params.n_batch = std::max(n_len, n_parallel);
+    ctx_params.n_parallel = n_parallel;
     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
@@ -132,7 +133,7 @@ int main(int argc, char ** argv) {
     // assign the system KV cache to all parallel sequences
     // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
     for (int32_t i = 1; i < n_parallel; ++i) {
-        llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
+        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
     }
 
     if (n_parallel > 1) {
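Both this example and batched-bench now pass -1 for the [p0, p1) bounds of llama_kv_cache_seq_cp, which copies all of sequence 0 into sequence i regardless of how many prompt tokens it currently holds. A minimal sketch of how such bounds are commonly resolved, assuming the convention that a negative bound means "no bound on that side"; resolve_range and seq_len are illustrative names, not library API:

    #include <cstdint>
    #include <cstdio>

    // resolve a [p0, p1) range: a negative value stands in for "no bound" on that side
    static void resolve_range(int32_t & p0, int32_t & p1, int32_t seq_len) {
        if (p0 < 0) { p0 = 0;       }  // from the first position
        if (p1 < 0) { p1 = seq_len; }  // through the last position currently held
    }

    int main() {
        int32_t p0 = -1, p1 = -1;
        resolve_range(p0, p1, 32);     // a sequence currently holding 32 prompt tokens
        printf("copy positions [%d, %d)\n", p0, p1);
        return 0;
    }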

examples/perplexity/perplexity.cpp

Lines changed: 6 additions & 3 deletions
@@ -809,7 +809,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const int n_batch = params.n_batch;
 
     const int max_tasks_per_batch = 32;
-    const int max_seq = 4*max_tasks_per_batch;
+    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
@@ -1086,7 +1086,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
     const int n_batch = params.n_batch;
 
     const int max_tasks_per_batch = 128;
-    const int max_seq = 2*max_tasks_per_batch;
+    const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
@@ -1438,7 +1438,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params) {
     const int n_batch = params.n_batch;
 
     const int max_tasks_per_batch = 32;
-    const int max_seq = 4*max_tasks_per_batch;
+    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
@@ -1815,6 +1815,9 @@ int main(int argc, char ** argv) {
     llama_model * model;
     llama_context * ctx;
 
+    // ensure there's at least enough seq_ids for HellaSwag
+    params.n_parallel = std::max(4, params.n_parallel);
+
     // load the model and apply lora adapter, if any
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
     if (model == NULL) {
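The two changes in this file work together: main raises params.n_parallel to at least 4 before the context is created (HellaSwag scores 4 endings per task, each in its own sequence), and each scoring function then clamps its sequence budget to whatever llama_n_max_seq reports. A small standalone sketch of that arithmetic, under the assumption that a Mamba-style context reports a sequence limit equal to n_parallel:

    #include <algorithm>
    #include <cstdio>

    int main() {
        int n_parallel = 1;                    // as supplied on the command line
        n_parallel = std::max(4, n_parallel);  // at least 4 seq_ids for HellaSwag's 4 endings

        const int max_tasks_per_batch = 32;
        const int n_max_seq = n_parallel;      // stand-in for llama_n_max_seq(ctx) on a Mamba context (assumption)
        const int max_seq   = std::min(4*max_tasks_per_batch, n_max_seq);

        printf("n_parallel = %d, max_seq = %d\n", n_parallel, max_seq);  // 4 and 4 instead of 128
        return 0;
    }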

llama.cpp

Lines changed: 4 additions & 0 deletions
@@ -12821,6 +12821,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
     return ctx->cparams.n_batch;
 }
 
+uint32_t llama_n_max_seq(const struct llama_context * ctx) {
+    return ctx->kv_self.size;
+}
+
 enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }

llama.h

Lines changed: 1 addition & 0 deletions
@@ -368,6 +368,7 @@ extern "C" {
 
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_max_seq  (const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
     LLAMA_API enum llama_rope_type  llama_rope_type (const struct llama_model * model);
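Together, the last two hunks expose the new limit through the public API, so an example can size its batch from the context instead of a hard-coded constant. A minimal sketch of that pattern, assuming an already-created llama_context; make_batch_for and n_seq_wanted are illustrative names, not part of the library:

    #include "llama.h"
    #include <algorithm>

    // size a batch so that no seq_id can exceed what the loaded model supports
    static llama_batch make_batch_for(llama_context * ctx, int n_ctx, int n_seq_wanted) {
        const int n_seq = std::min(n_seq_wanted, (int) llama_n_max_seq(ctx));
        return llama_batch_init(n_ctx, 0, n_seq);
    }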
