Commit 1321439

tools : tmp adjustments (TMP)
ggml-ci
1 parent 52b9007 commit 1321439

File tree

    examples/parallel/parallel.cpp
    tools/batched-bench/batched-bench.cpp

2 files changed: +15 additions, -9 deletions


examples/parallel/parallel.cpp

Lines changed: 11 additions & 5 deletions
@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(n_ctx, 0, 1);
+    llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
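The batch is now sized for the worst case where every client submits a context-sized prompt in the same iteration; as the comment above notes, the main loop still feeds the model at most params.n_batch tokens at a time. A minimal sketch of this allocation pattern, with an illustrative helper name and the sizing inputs passed as parameters (not part of the commit itself):

    #include "llama.h"

    // illustrative helper: reserve a token-only batch (embd = 0) large enough for
    // every client to submit a full-context prompt in the same iteration
    static llama_batch make_parallel_batch(int32_t n_ctx, int32_t n_clients) {
        return llama_batch_init(n_ctx*n_clients, 0, 1);
    }

    // the batch is then filled with common_batch_add(), decoded in chunks, and
    // eventually released with llama_batch_free(batch)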
@@ -289,8 +289,11 @@
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
                 llama_memory_seq_rm(mem, i, -1, -1);
-                // but keep the system prompt
-                llama_memory_seq_cp(mem, 0, i, -1, -1);
+
+                if (is_sp_shared) {
+                    // but keep the system prompt
+                    llama_memory_seq_cp(mem, 0, i, -1, -1);
+                }
             }
 
             LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -449,8 +452,11 @@ int main(int argc, char ** argv) {
                 }
 
                 // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                llama_memory_seq_rm(mem, client.id + 1, -1, -1);
-                llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
+                llama_memory_seq_rm(mem, client.id + 1, -1, -1);
+
+                if (is_sp_shared) {
+                    llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
+                }
 
                 const auto t_main_end = ggml_time_us();
 
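Both cleanup paths above now re-seed a client's sequence from sequence 0 only when the system prompt is actually shared across clients. A minimal sketch of that reset pattern, using the same llama_memory_seq_rm/llama_memory_seq_cp calls as the diff (the helper name is illustrative, not code from the commit):

    // illustrative helper: drop everything cached for a client's sequence and,
    // only if the system prompt is shared, copy it back from sequence 0
    static void reset_client_seq(llama_memory_t mem, llama_seq_id seq_id, bool is_sp_shared) {
        llama_memory_seq_rm(mem, seq_id, -1, -1);             // -1, -1 removes the full range

        if (is_sp_shared) {
            llama_memory_seq_cp(mem, 0, seq_id, -1, -1);      // keep the shared system prompt
        }
    }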

tools/batched-bench/batched-bench.cpp

Lines changed: 4 additions & 4 deletions
@@ -61,7 +61,7 @@ int main(int argc, char ** argv) {
 
     const int32_t n_kv_max = llama_n_ctx(ctx);
 
-    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
+    llama_batch batch = llama_batch_init(n_kv_max*8, 0, 1); // TODO: tmp!!!
 
     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
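The preserved comment ("decode in batches of ctx_params.n_batch tokens") is what keeps the oversized allocation workable: the enlarged batch is never handed to llama_decode in one piece. A rough sketch of such a chunked decode loop, written as a free function rather than the example's decode_helper lambda and simplified to a boolean result (illustrative, not a verbatim copy):

    #include <algorithm>
    #include "llama.h"

    // split a large batch into views of at most n_batch tokens and decode each view
    static bool decode_in_chunks(llama_context * ctx, llama_batch & batch, int32_t n_batch) {
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

            llama_batch view = {};               // embd stays null: this is a token batch
            view.n_tokens = n_tokens;
            view.token    = batch.token    + i;  // token ids for this chunk
            view.pos      = batch.pos      + i;  // positions for this chunk
            view.n_seq_id = batch.n_seq_id + i;
            view.seq_id   = batch.seq_id   + i;
            view.logits   = batch.logits   + i;

            if (llama_decode(ctx, view) != 0) {
                return false;                    // decode failed (e.g. no free KV slot)
            }
        }
        return true;
    }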
@@ -119,9 +119,9 @@ int main(int argc, char ** argv) {
 
                 const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
 
-                if (n_ctx_req > n_kv_max) {
-                    continue;
-                }
+                //if (n_ctx_req > n_kv_max) {
+                //    continue;
+                //}
 
                 common_batch_clear(batch);
 
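A worked example of the requirement computed above: with pp = 512, tg = 128 and pl = 4, a shared prompt needs pp + pl*tg = 512 + 4*128 = 1024 KV cells, while per-sequence prompts need pl*(pp + tg) = 4*(512 + 128) = 2560. With the guard commented out, combinations where n_ctx_req exceeds n_kv_max are no longer skipped, in line with the temporary n_kv_max*8 allocation earlier in the file.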
