@@ -2955,7 +2955,8 @@ struct server_context {
                 llama_kv_self_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
                 llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
 
-                if (slot.params.cache_prompt) {
+                // add generated tokens to cache
+                {
                     llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
                     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                         new_tokens[i - n_discard] = new_tokens[i];
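
Note: the hunk above replaces the cache_prompt guard with an unconditional block, so the in-memory token copy is always compacted to match the shifted KV cache. A minimal standalone sketch of that compaction, assuming a plain std::vector<int32_t> stands in for slot.cache_tokens:

    #include <cstdint>
    #include <vector>

    // drop the n_discard tokens that follow the first n_keep tokens,
    // sliding the remainder down, as in the loop in the hunk above
    static void discard_after_keep(std::vector<int32_t> & tokens, size_t n_keep, size_t n_discard) {
        for (size_t i = n_keep + n_discard; i < tokens.size(); i++) {
            tokens[i - n_discard] = tokens[i];
        }
        tokens.resize(tokens.size() - n_discard); // shrink to the compacted size
    }
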
@@ -3000,10 +3001,7 @@ struct server_context {
             common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
-
-            if (slot.params.cache_prompt) {
-                slot.cache_tokens.push_back(slot.sampled);
-            }
+            slot.cache_tokens.push_back(slot.sampled);
 
             SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
                 slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
@@ -3175,6 +3173,11 @@ struct server_context {
 
                     SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
                 }
+            } else {
+                // if we don't cache the prompt, we have to remove the entire KV cache
+                llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
+                slot.n_past = 0;
+                slot.cache_tokens.clear();
             }
         }
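
Note: in the llama.cpp KV cache API, a negative p1 means "up to the end of the sequence", so the call above removes every cached cell for this slot. An annotated restatement of the new else branch (same calls as in the hunk):

    // prompt caching disabled: nothing from the previous request may be
    // reused, so the slot starts from scratch on every request
    llama_kv_self_seq_rm(ctx, slot.id, 0, -1); // p1 = -1: remove up to the end
    slot.n_past = 0;                           // no decoded tokens are reusable
    slot.cache_tokens.clear();                 // keep the copy in sync with the KV cache
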
@@ -3208,7 +3211,7 @@ struct server_context {
                 SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
 
                 // remove the non-common part from the cache
-                slot.cache_tokens.resize(slot.n_past);
+                slot.cache_tokens.keep_first(slot.n_past);
 
                 // check if we should process the image
                 if (slot.n_past < slot.n_prompt_tokens
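
Note: resize() is replaced by keep_first(), which can only shrink the container. A hypothetical, simplified sketch of such a helper (the real server_tokens also tracks multimodal chunks, which this sketch omits):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct tokens_sketch {
        std::vector<int32_t> toks;

        // keep only the first n tokens; unlike resize(), this never grows
        // the container, which is what the new name makes explicit
        void keep_first(size_t n) {
            assert(n <= toks.size());
            toks.resize(n);
        }
    };
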
@@ -3225,7 +3228,8 @@ struct server_context {
                     continue;
                 }
 
-                if (slot.params.cache_prompt) {
+                // add the image chunk to cache
+                {
                     const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
                     slot.cache_tokens.push_back(chunk.get()); // copy
                 }
@@ -3246,9 +3250,7 @@ struct server_context {
                     const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
                     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
-                    if (slot.params.cache_prompt) {
-                        slot.cache_tokens.push_back(cur_tok);
-                    }
+                    slot.cache_tokens.push_back(cur_tok);
 
                     slot.n_prompt_tokens_processed++;
                     slot.n_past++;
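
Note: with the cache_prompt guards removed in the hunks above (generated tokens, image chunks, and text prompt tokens), every token that enters the batch is also appended to slot.cache_tokens, so the copy always mirrors the KV cache. A standalone sketch of that invariant, using hypothetical stand-in types:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct slot_sketch {
        std::vector<int32_t> batch;        // stand-in for llama_batch
        std::vector<int32_t> cache_tokens; // stand-in for server_tokens
        int n_past = 0;

        void add_token(int32_t tok) {
            batch.push_back(tok);        // stand-in for common_batch_add()
            cache_tokens.push_back(tok); // unconditional, no cache_prompt guard
            n_past++;
            assert((int) cache_tokens.size() == n_past); // copy mirrors the KV cache
        }
    };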