Feature: prompt caching (Fixes #310) #312
base: main
Conversation
/// ``PromptCache`` is ``@unchecked Sendable`` which allows it
/// to be used within the ``ModelContainer`` context.
///
/// TODO: cache isolation
I think this part is key -- I think it will need a lock and a method like:
func withCache<R>(_ block: ([KVCache]) throws -> R) rethrows -> R
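For what it's worth, a minimal sketch of the kind of lock-guarded accessor being suggested; the PromptCacheHolder name, the NSLock, and the MLXLMCommon import are assumptions for illustration, not part of this PR:

import Foundation
import MLXLMCommon  // assumption: KVCache is defined here

/// Hypothetical wrapper that serializes access to the KV caches.
final class PromptCacheHolder: @unchecked Sendable {
    private let lock = NSLock()
    private var caches: [KVCache] = []

    /// Run `block` with exclusive access to the caches.
    func withCache<R>(_ block: ([KVCache]) throws -> R) rethrows -> R {
        lock.lock()
        defer { lock.unlock() }
        return try block(caches)
    }
}

Whether something like this composes with the asynchronous ModelContainer.perform() is exactly the open question raised in the reply below.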
I agree. I've forgotten some of the details but I had roadblocks with each approach. I think there was an issue with ModelContainer.perform() being asynchronous and trying to wrap that with something like withCache.
I might have to leave it to someone with more expertise in Swift concurrency.
/// Remove prefix from prompt that is already in cache
if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
    lmInput = LMInput(text: LMInput.Text(tokens: suffix))
I thought that the KVCache and the prompt should match in size -- that is, I thought the prompt should not have the pieces that are already in the KVCache trimmed off. Hrm, perhaps I am confused; here is the LLM prefill code:
while y.tokens.size > prefillStepSize {
    let input = y[.newAxis, ..<prefillStepSize]
    let result = self(input, cache: cache.isEmpty ? nil : cache, state: state)
    eval(cache)
    y = y[prefillStepSize...]
}
and I think it matches what you are doing here.
Looking at that code, it looks to me like it is just passing the entire prompt (whatever was passed to it) through the model, using the existing cache.
To be able to do trimming, the cache or the model would need to have a record of all of the tokens in the cache up to that point, and I'm not sure it does. But if it did, the trimming logic could be moved there.
One issue we have at the moment with PromptCache is that it is responsible for recording the tokens represented by the cache. However, because AsyncStream doesn't return tokens, we have no way of updating the cache's token list with the generated response. As a result we always end up re-processing the previous response with the new prompt, because the cache doesn't know it is already there.
Either AsyncStream should return the tokens (which may not be a bad idea anyway), or the cache should be moved closer to where the tokens are generated so they can be added there. However, KVCache doesn't have a record of the tokens (I don't think?), which is why we need PromptCache with its tokens MLXArray. So if the cache were to be fully managed within TokenIterator, we would need to pass it the full PromptCache.
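As a rough illustration of the comparison being described here (this is not the PR's actual getUncachedSuffix; token arrays are shown as [Int] instead of MLXArray so the sketch is self-contained):

/// Hypothetical prefix check: return the part of `prompt` not already covered by
/// `cachedTokens`, or nil if the prompt has diverged and the cache can't be reused as-is.
func uncachedSuffix(cachedTokens: [Int], prompt: [Int]) -> [Int]? {
    guard prompt.count >= cachedTokens.count,
        Array(prompt.prefix(cachedTokens.count)) == cachedTokens
    else {
        return nil  // mismatch: the caller has to reset or trim the cache
    }
    // Only the new portion of the prompt needs to be prefilled through the model.
    return Array(prompt.dropFirst(cachedTokens.count))
}

Because the recorded tokens never include the generated response, the response portion of the next prompt always falls into the uncached suffix and is prefilled again, which is the behaviour described above.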
I understand the chat (turn taking) style use of KVCache a little bit more. I don't think we need to observe tokens directly -- KVCache already represents that.
If we want to trim the cache we have a couple of options:
- we can record the index for particular points in the conversation and roll back to those
- we can (I think) tokenize a truncated prompt with the back and forth conversation to get a token count and trim to that offset
If we put the tokens in the AsyncStream, that requires the holder of the stream to use the streaming detokenizer (not Sendable) and complicates things.
Now you may be right -- there may be certain cases where we do need the tokens, but I wonder if that should be handled synchronously inside the TokenIterator?
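A rough sketch of the first option (recording a token count at known points and rolling back), assuming the trim(n:) method this PR adds to KVCache and that the caller tracks token counts itself; all names here are illustrative:

import MLXLMCommon  // assumption: KVCache and the new trim(n:) are visible here

/// Hypothetical checkpoint: how many tokens the caches held at a known point
/// in the conversation (for example, the end of a user turn).
struct CacheCheckpoint {
    let tokenCount: Int
}

/// Roll the caches back to a checkpoint by trimming off everything after it.
/// Returns the token count the caches now represent.
func rollBack(
    caches: [KVCache], currentTokenCount: Int, to checkpoint: CacheCheckpoint
) -> Int {
    let excess = currentTokenCount - checkpoint.tokenCount
    guard excess > 0 else { return currentTokenCount }
    for cache in caches {
        _ = cache.trim(n: excess)  // discard the tokens past the checkpoint
    }
    return checkpoint.tokenCount
}

The second option would compute the checkpoint's token count by re-tokenizing the truncated conversation instead of recording it up front.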
/// - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
/// - Returns: the generated output
public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext,
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
I think some of this is already in place now.
Overall I like this direction. I think it needs:
What do you think of this with regard to #330? That encapsulates the KVCache, which is part of this PR, but I think the cache manipulation is still key.
print("[getUncachedSuffix] self.tokens.size = \(self.tokens.size)") | ||
|
||
print("cache[\(self.tokens.size)]: \(self.tokens)") | ||
print("prompt[\(prompt.size)]: \(prompt)") |
- remove debug printing
Yes. Just a comment on the streamlined approach:
Just to clarify: the Qwen3 thinking example may behave differently to examples that remove the `<think>` block.
I think this is a little bit complicated, but it goes something like this: we would have to consider what happens if the caller terminates the iteration early. Maybe it isn't even part of the Iterator (ideally we could compose this).
Something to consider with trimming: it is possible that a token at the boundary could differ after re-tokenization (tokenizer inconsistencies). If we can't edit the cache at the token level, then the entire previous response has to be removed (as it is guaranteed to be on a clean token boundary), and it is pre-filled again as if a new prompt. Note that trim only adjusts the cache offset:
public func trim(n: Int) -> Int {
    let toTrim = min(self.offset, n)
    self.offset -= toTrim
    return toTrim
}
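So, with that implementation, dropping the whole previous response would look something like the sketch below, assuming the cache's offset is visible to the caller and that the caller recorded how many prompt tokens the cache is supposed to hold; the helper name is hypothetical:

/// Hypothetical helper: the cache holds prompt + previous response, but only the
/// prompt tokens were recorded, so trim the response off before the next turn.
func dropPreviousResponse(from cache: KVCacheSimple, recordedPromptTokens: Int) {
    let excess = cache.offset - recordedPromptTokens
    if excess > 0 {
        // trim(n:) caps at the current offset and returns what it actually removed.
        let removed = cache.trim(n: excess)
        assert(removed == excess)
    }
}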
Yeah, we probably have to handle that. FWIW, mlx-lm (the Python side) does not have this capability yet.
Fixes: #310
Currently there is no way to persist the cache between calls to `generate()`. This is a relatively simple fix by adding `[KVCache]` parameters to the `generate()` functions, which are then passed to the `TokenIterator`.

Trim functions have been added to the `KVCache` protocol and implemented in `KVCacheSimple`. Even though not strictly necessary for caching, it is not uncommon for the new prompt to be partially inconsistent with the cache, either through tokenizer inconsistencies or recent messages being intentionally manipulated (e.g. removing a `<think>` block).

An example of how to implement a prompt cache has been added to MLXChatExample. The time to first token (TTFT) is now also displayed, which is helpful to see the performance improvement from caching.

The prompt cache is implemented in the `PromptCache` class, which is `@unchecked Sendable` so that it can be used within the `ModelContainer` context. Currently there is no isolation on the `KVCache` in `PromptCache`.

Note that when using the `AsyncStream` versions of `generate()` there is no way to return token ids, so the newly generated response can't be added to the cache, and it will be reprocessed again on the next message. Perhaps the tokens could be added to `Generation.chunk`?
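To illustrate the intended usage, a sketch of a caller that keeps one set of KV caches alive across turns; ChatSession is a placeholder, the cache construction is left to the caller, and the exact generate signature, return type, and error behaviour follow the diff above rather than a final API:

import MLXLMCommon  // assumption: LMInput, GenerateParameters, ModelContext, KVCache, generate(...)

/// Hypothetical session object: reusing the same [KVCache] on every call means the
/// shared prompt prefix is not re-processed each turn.
final class ChatSession {
    let context: ModelContext
    let parameters: GenerateParameters
    var cache: [KVCache]  // persists between turns

    init(context: ModelContext, parameters: GenerateParameters, cache: [KVCache]) {
        self.context = context
        self.parameters = parameters
        self.cache = cache
    }

    func respond(to input: LMInput) throws -> GenerateResult {
        try generate(
            input: input,
            parameters: parameters,
            context: context,
            cache: cache  // same caches every turn
        ) { _ in
            .more  // streaming / stop handling elided for the sketch
        }
    }
}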