@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(n_ctx, 0, 1);
+    llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
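The comment in this hunk notes that the main loop chunks the (now much larger) batch into at most params.n_batch tokens per decode call. For reference, a minimal sketch of that chunking pattern could look like the following; the helper name decode_chunked is hypothetical and not part of the patch, and the llama_batch field layout assumed here is the one from recent llama.h:

    #include <algorithm>
    #include <cstdint>

    #include "llama.h"

    // Hypothetical helper: decode a pre-filled batch in chunks of at most
    // n_batch tokens, as the comment in the hunk above describes.
    static bool decode_chunked(llama_context * ctx, const llama_batch & batch, int32_t n_batch) {
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
            const int32_t n_tok = std::min(n_batch, batch.n_tokens - i);

            // a view into the big batch - no copies are made
            llama_batch view = {
                n_tok,
                batch.token    + i,
                nullptr,              // token ids are used here, not embeddings
                batch.pos      + i,
                batch.n_seq_id + i,
                batch.seq_id   + i,
                batch.logits   + i,
            };

            if (llama_decode(ctx, view) != 0) {
                return false; // decode failed, e.g. no more KV cache space
            }
        }
        return true;
    }

With the batch now sized n_ctx*n_clients instead of n_ctx, this chunking is what keeps each individual decode call within the configured compute batch size.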
@@ -289,8 +289,11 @@ int main(int argc, char ** argv) {
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
                 llama_memory_seq_rm(mem, i, -1, -1);
-                // but keep the system prompt
-                llama_memory_seq_cp(mem, 0, i, -1, -1);
+
+                if (is_sp_shared) {
+                    // but keep the system prompt
+                    llama_memory_seq_cp(mem, 0, i, -1, -1);
+                }
             }
 
             LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -449,8 +452,11 @@ int main(int argc, char ** argv) {
                }
 
                // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                llama_memory_seq_rm(mem, client.id + 1, -1, -1);
-                llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
+                llama_memory_seq_rm(mem, client.id + 1, -1, -1);
+
+                if (is_sp_shared) {
+                    llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
+                }
 
                const auto t_main_end = ggml_time_us();
 
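Both hunks gate llama_memory_seq_cp behind is_sp_shared: when the system prompt is shared, sequence 0 holds it once and each client sequence just re-links to it instead of re-decoding it. A minimal sketch of that reset pattern, assuming a ctx whose sequence 0 already contains the decoded system prompt, and that llama_get_memory is available as in recent llama.cpp builds (the function name reset_clients is an assumption, not code from the patch):

    #include "llama.h"

    // Hypothetical helper mirroring the patched loop: wipe each client's
    // sequence and, only if the system prompt is shared, copy sequence 0
    // (which holds the system prompt) back into it.
    static void reset_clients(llama_context * ctx, int n_clients, bool is_sp_shared) {
        llama_memory_t mem = llama_get_memory(ctx);

        for (int i = 1; i <= n_clients; ++i) {
            // drop everything this client accumulated in the KV cache
            llama_memory_seq_rm(mem, i, -1, -1);

            if (is_sp_shared) {
                // re-link the shared system prompt without re-decoding it
                llama_memory_seq_cp(mem, 0, i, -1, -1);
            }
        }
    }

Skipping the copy when is_sp_shared is false is presumably the point of the change: with per-client system prompts there is no shared prefix in sequence 0, so the old unconditional copy had nothing valid to propagate.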