Feature: prompt caching (Fixes #310) #312

Open · jolonf wants to merge 1 commit into main from feature/prompt-caching

Conversation

@jolonf jolonf commented May 5, 2025

Fixes: #310

Currently there is no way to persist the cache between calls to generate().

This is a relatively simple fix: a [KVCache] parameter is added to the generate() functions and passed through to the TokenIterator.

Trim functions have been added to the KVCache protocol and implemented in KVCacheSimple. Trimming is not strictly necessary for caching, but it is not uncommon for a new prompt to be partially inconsistent with the cache, either through tokenizer inconsistencies or because recent messages were intentionally manipulated (e.g. removing a <think> block).

An example of how to implement a prompt cache has been added to MLXChatExample. The time to first token (TTFT) is now also displayed, which is helpful for seeing the performance improvement from caching.

The prompt cache is implemented in the PromptCache class, which is @unchecked Sendable so that it can be used within the ModelContainer context. Currently there is no isolation on the KVCache in PromptCache.

Note that if using the AsyncStream versions of generate() there is no way to return token ids, so the newly generated response can't be added to the cache and will be reprocessed on the next message. Perhaps the tokens could be added to the Generation.chunk?
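
For orientation, here is a hedged sketch of how the new cache parameter might be used across turns. The generate(...) shape follows this PR's diff; newCache(parameters:) and the .more disposition are assumptions about the surrounding MLXLMCommon API rather than code from this PR.

    import MLXLMCommon

    // Sketch: one KVCache array shared across two generate() calls so the
    // second turn reuses the first turn's KV state.
    func twoTurns(
        context: ModelContext, parameters: GenerateParameters,
        turn1: LMInput, turn2: LMInput
    ) throws {
        // Assumed helper for creating a cache sized for this model.
        let cache = context.model.newCache(parameters: parameters)

        _ = try generate(
            input: turn1, parameters: parameters, context: context, cache: cache
        ) { _ in .more }

        // Passing the same cache back in means only the uncached suffix of the
        // second prompt needs prefilling, which is what improves TTFT.
        _ = try generate(
            input: turn2, parameters: parameters, context: context, cache: cache
        ) { _ in .more }
    }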

@jolonf jolonf mentioned this pull request May 5, 2025
@jolonf jolonf force-pushed the feature/prompt-caching branch from 50ff67a to e3645dc on May 5, 2025 05:23
/// ``PromptCache`` is ``@unchecked Sendable`` which allows it
/// to be used within the ``ModelContainer`` context.
///
/// TODO: cache isolation
Collaborator

I think this part is key -- I think it will need a lock and a method like:

    func withCache<R>(_ block: ([KVCache]) throws -> R) rethrows -> R
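
A minimal sketch of that idea, assuming NSLock and a class shaped roughly like the PR's PromptCache (the names here are illustrative, not the repository's actual API):

    import Foundation
    import MLXLMCommon  // assumed home of the KVCache protocol

    // All access to the underlying [KVCache] goes through withCache, so the
    // lock justifies the @unchecked Sendable annotation.
    final class LockedPromptCache: @unchecked Sendable {
        private let lock = NSLock()
        private var cache: [KVCache]

        init(cache: [KVCache]) {
            self.cache = cache
        }

        func withCache<R>(_ block: ([KVCache]) throws -> R) rethrows -> R {
            lock.lock()
            defer { lock.unlock() }
            return try block(cache)
        }
    }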

Author

I agree. I've forgotten some of the details but I had roadblocks with each approach. I think there was an issue with ModelContainer.perform() being asynchronous and trying to wrap that with something like withCache.

I might have to leave it to someone with more expertise in Swift concurrency.


/// Remove prefix from prompt that is already in cache
if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
    lmInput = LMInput(text: LMInput.Text(tokens: suffix))
Collaborator

I thought that the KVCache and the prompt should match in size -- that is, I thought the prompt should not have the pieces that are already in the KVCache trimmed off. Hrm, perhaps I am confused -- here is the LLM prefill code:

        while y.tokens.size > prefillStepSize {
            let input = y[.newAxis, ..<prefillStepSize]
            let result = self(input, cache: cache.isEmpty ? nil : cache, state: state)
            eval(cache)
            y = y[prefillStepSize...]
        }

and I think it matches what you are doing here.

Author

Looking at that code, it seems to me that it just passes the entire prompt (whatever was passed in) through the model, using the existing cache.

To be able to do trimming, the cache or the model would need a record of all of the tokens in the cache up to that point; I'm not sure it has one. But if it did, the trimming logic could be moved there.

One issue we have at the moment with PromptCache is that it is responsible for recording the tokens represented by the cache. However, because AsyncStream doesn't return tokens we have no way of updating the token list for the cache with the generated response. As a result we always trim the previous response from the new prompt because it doesn't know it is in the cache.

Either AsyncStream should return the tokens (which may not be a bad idea anyway), or the cache should move closer to where the tokens are generated so they can be added there. However, KVCache doesn't have a record of the tokens (I don't think?), which is why we need PromptCache with its tokens MLXArray; so if the cache were to be fully managed within the TokenIterator, we would need to pass it the full PromptCache.
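
To make the role of that token record concrete, here is a simplified, hedged model of the getUncachedSuffix idea using [Int] instead of MLXArray (the real PromptCache stores an MLXArray); only KVCache.trim(n:) comes from this PR, the rest is illustrative:

    import MLXLMCommon

    final class PromptCacheSketch {
        private(set) var tokens: [Int] = []  // tokens the KVCache currently represents
        let cache: [KVCache]

        init(cache: [KVCache]) { self.cache = cache }

        /// Returns only the part of `prompt` that still needs prefilling,
        /// trimming the KVCache first if the prompt diverges from the record.
        func getUncachedSuffix(prompt: [Int]) -> [Int]? {
            let common = zip(tokens, prompt).prefix { $0.0 == $0.1 }.count

            if common < tokens.count {
                // Prompt diverged (e.g. a removed <think> block): drop the
                // stale cache entries so the cache matches the common prefix.
                let stale = tokens.count - common
                for c in cache { _ = c.trim(n: stale) }
                tokens = Array(tokens[..<common])
            }

            guard common < prompt.count else { return nil }  // fully cached
            // Record the full prompt, on the assumption that the caller will
            // prefill the returned suffix.
            tokens = prompt
            return Array(prompt[common...])
        }
    }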

Collaborator

I understand the chat (turn-taking) style use of KVCache a little bit more now. I don't think we need to observe tokens directly -- the KVCache already represents that.

If we want to trim the cache we have a couple of options:

  • we can record the index for particular points in the conversation and roll back to those
  • we can (I think) tokenize a truncated prompt with the back and forth conversation to get a token count and trim to that offset

If we put the tokens in the AsyncStream, that requires the holder of the stream to use the streaming detokenizer (which is not Sendable) and complicates things.

Now you may be right -- there may be certain cases where we do need the tokens, but I wonder if that should be handled synchronously inside the TokenIterator?
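
A small sketch of the first option (recording an index and rolling back), assuming the offset property on KVCache and the trim(n:) method proposed in this PR; the checkpoint bookkeeping itself is hypothetical:

    import MLXLMCommon

    struct TurnCheckpoint {
        let offsets: [Int]  // one saved offset per KVCache layer
    }

    func checkpoint(_ cache: [KVCache]) -> TurnCheckpoint {
        TurnCheckpoint(offsets: cache.map { $0.offset })
    }

    func rollBack(_ cache: [KVCache], to checkpoint: TurnCheckpoint) {
        // Trim everything added since the checkpoint (e.g. a <think> block or
        // a whole turn) so it can be re-prefilled with edited content.
        for (c, saved) in zip(cache, checkpoint.offsets) {
            let excess = c.offset - saved
            if excess > 0 { _ = c.trim(n: excess) }
        }
    }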

/// - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
/// - Returns: the generated output
public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext,
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
Collaborator

I think some of this is already in place now.

@davidkoski
Collaborator

Overall I like this direction. I think it needs:

  • finish the borrowing of the KVCache (implement a visitor that holds a lock so we can satisfy the unchecked Sendable)
  • remove the debug printing

@davidkoski
Collaborator

What do you think of this with regard to #330? That encapsulates the KVCache which is part of this PR, but I think the cache manipulation is still key.

print("[getUncachedSuffix] self.tokens.size = \(self.tokens.size)")

print("cache[\(self.tokens.size)]: \(self.tokens)")
print("prompt[\(prompt.size)]: \(prompt)")
Collaborator

  • remove debug printing

@jolonf
Author

jolonf commented Jun 25, 2025

> What do you think of this with regard to #330? That encapsulates the KVCache which is part of this PR, but I think the cache manipulation is still key.

Yes, passing KVCache to generate is now there, which is the critical part. It at least allows developers to add their own prompt-cache handling.

Just a comment on the streamlined approach and the EvaluateLLM example, which defaults to Qwen3: Qwen specifically states that <think> blocks should not be included in the chat history. If they aren't included, then the KVCache needs to be trimmed, because it will contain the most recent <think> block from the last generated response. So technically the example should have cache trimming, but I don't think it is an issue for EvaluateLLM. ChatSession should probably have it, though?

@jolonf
Author

jolonf commented Jun 25, 2025

Just to clarify: if the <think> blocks are removed, then the LLM will respond in the next turn with no knowledge of any prior <think> blocks.

If the KVCache isn't trimmed, all of the prior <think> blocks will still be in it, and the LLM will be aware of them.

So the Qwen3 thinking example may behave differently from examples that remove the <think> blocks. I haven't tested both ways to see how the responses differ.

@davidkoski
Collaborator

Reference on the <think> tags: https://api-docs.deepseek.com/guides/reasoning_model

I think this is a little bit complicated, but it goes something like this:

  • iterator keeps track of start index when doing generation
  • produces full output
  • goes back through tokens and filters out the <think> section(s)
  • resets the KVCache to the start index
  • prefills the KVCache with the response (like a prompt)

We would have to consider what happens if the caller terminates the iteration early. Maybe it isn't even part of the Iterator (ideally we could compose this).

Anyway, the <think> section will be in the KVCache as a consequence of generation, so we have to replay the output to replace what is in there.
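
A hedged sketch of that replay step: filter the <think> section(s) out of the decoded output, roll the cache back to the recorded start offsets, and re-tokenize the filtered response so it can be prefilled like a prompt. The Tokenizer protocol and encode(text:) are assumed from swift-transformers; trim(n:) is the method proposed in this PR.

    import Foundation
    import MLXLMCommon
    import Tokenizers

    func tokensToReplay(
        response: String,       // full decoded output, including <think> blocks
        startOffsets: [Int],    // cache offsets recorded before generation began
        cache: [KVCache],
        tokenizer: Tokenizer
    ) -> [Int] {
        // 1. Drop the <think>...</think> section(s) from the decoded text.
        let filtered = response.replacingOccurrences(
            of: "<think>[\\s\\S]*?</think>\\s*",
            with: "",
            options: .regularExpression)

        // 2. Reset the KVCache to where this turn's generation started.
        for (c, start) in zip(cache, startOffsets) {
            let excess = c.offset - start
            if excess > 0 { _ = c.trim(n: excess) }
        }

        // 3. Re-tokenize the filtered response; the caller prefills these
        //    tokens into the cache exactly as it would a prompt.
        return tokenizer.encode(text: filtered)
    }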

@jolonf
Author

jolonf commented Jun 26, 2025

Something to consider, if the <think> blocks are to be removed at the token level, is that the closing </think> tag could form a token with the start of the following response. For example, if the response looked like:

<think>
Thinking... thinking...
</think>
Here is my response

It is possible that a token could be >Here, which wouldn't allow the <think> block to be properly removed at the token level.

If the <think> block could be removed at the token level, an even better optimisation would be to snip the <think> block out, not just trim. But I don't know enough about LLMs and KV caches to know if this would work.

If we can't edit the cache at the token level, then the entire previous response has to be removed (as it is guaranteed to be on a clean token boundary) and pre-filled again as if it were a new prompt.

Note that the KVCache trimming is extremely efficient; it just adjusts the offset in the current implementation in the PR:

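    /// Reduce the cache offset by up to `n` tokens; returns how many were trimmed.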
    public func trim(n: Int) -> Int {
        let toTrim = min(self.offset, n)
        self.offset -= toTrim
        return toTrim
    }

@davidkoski
Collaborator

Yeah, we probably have to handle the <think> tags in decoded token space: trim them out, then re-tokenize.

FWIW mlx-lm (python side) does not have this capability yet.
