Feature: prompt caching (Fixes #310) #312
base: main
Changes from all commits
PromptCache.swift (new file)
@@ -0,0 +1,124 @@
//
//  PromptCache.swift
//  mlx-swift-examples
//
//  Created by Jolon Faichney on 3/5/2025.
//

import MLX
import MLXLMCommon

/// Stores the KV cache between calls to ``generate`` and maintains
/// the token ids reflected in the cache.
///
/// ``PromptCache`` is ``@unchecked Sendable`` which allows it
/// to be used within the ``ModelContainer`` context.
///
/// TODO: cache isolation
public class PromptCache: @unchecked Sendable {
    private(set) var cache: [KVCache]
    private(set) var tokens: MLXArray

    public init(cache: [KVCache]) {
        print("[PromptCache.init]")
        self.cache = cache
        self.tokens = []
    }

    /// Returns the suffix of the prompt that is not already in the cache, so that
    /// only the new part is processed. The cache's token array is updated here to
    /// reflect the new full prompt (i.e. the suffix tokens are appended), on the
    /// assumption that the caller will process the suffix after this call.
    ///
    /// Trims the cache if part of it does not match the new prompt. If the cache
    /// needs trimming but the model does not support it, returns nil so the caller
    /// can create a new cache.
    ///
    /// - Returns:
    ///   - If the entire cache is a prefix of the new prompt: the suffix of the
    ///     new prompt beyond what is already cached.
    ///   - If only a portion of the cache matches the new prompt: the cache is
    ///     trimmed to the common prefix and the prompt suffix not in the cache
    ///     is returned.
    ///   - If the cache needs trimming but is not trimmable: nil, so the caller
    ///     can create a new cache.
    public func getUncachedSuffix(prompt: MLXArray) -> MLXArray? {

        print("[getUncachedSuffix] self.tokens.size = \(self.tokens.size)")

        print("cache[\(self.tokens.size)]: \(self.tokens)")
        print("prompt[\(prompt.size)]: \(prompt)")

        let comPrefixLength = commonPrefixLength(newPromptTokens: prompt)
        print("[getUncachedSuffix] comPrefixLength: \(comPrefixLength)")

        if comPrefixLength == self.tokens.size {
            let suffix = prompt[comPrefixLength ..< prompt.size]
            print("Concatenating...")
            self.tokens = concatenated([self.tokens, suffix], axis: 0)
            return suffix
        } else if comPrefixLength < self.tokens.size {
            if isTrimmable() {
                print("trimming: \(self.tokens.size - comPrefixLength)")
                let trimmedLen = self.trim(self.tokens.size - comPrefixLength)
                print("trimmed: \(trimmedLen)")
                if trimmedLen != self.tokens.size - comPrefixLength {
                    print("Warning: requested trim amount and actual trimmed amount differ")
                }
                self.tokens = self.tokens[0 ..< comPrefixLength]
                let suffix = prompt[comPrefixLength ..< prompt.size]
                self.tokens = concatenated([self.tokens, suffix], axis: 0)
                return suffix
            } else {
                // Caller must create a new cache
                return nil
            }
        }

        return nil
    }

    /// - Returns: true if all KV caches are trimmable.
    public func isTrimmable() -> Bool {
        return cache.allSatisfy { $0.isTrimmable() }
    }

    /// Trims all KV caches.
    /// - Parameters:
    ///   - n: Amount to trim.
    /// - Returns: Amount the KV caches were trimmed (may be less than ``n``).
    public func trim(_ n: Int) -> Int {
        if !self.isTrimmable() {
            return 0
        }
        return cache.map { $0.trim(n: n) }.max() ?? 0
    }

    /// Finds the common prefix between the cached prompt and the new prompt.
    /// - Parameters:
    ///   - newPromptTokens: Tokens to compare with cached tokens.
    /// - Returns: Length of the common prefix.
    public func commonPrefixLength(newPromptTokens: MLXArray) -> Int {
        return commonPrefixLength(self.tokens, newPromptTokens)
    }

    /// Finds the common prefix between two ``MLXArray``s.
    /// - Parameters:
    ///   - array1: First array
    ///   - array2: Second array
    /// - Returns: Length of the common prefix.
    public func commonPrefixLength(_ array1: MLXArray, _ array2: MLXArray) -> Int {
        // TODO: Add test cases
        print("Calculating common prefix: array1[\(array1.size)] array2[\(array2.size)]")
        let minLength = min(array1.size, array2.size)
        for i in 0 ..< minLength {
            if all(array1[i] .!= array2[i]).item(Bool.self) {
                return i
            }
        }
        return minLength
    }
}
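To make the intended multi-turn flow concrete, here is a small illustrative sketch of how `getUncachedSuffix` behaves across two chat turns. It is not part of the PR: the token values are made up, and `kvCaches` stands in for `context.model.newCache(parameters: parameters)` as used in the MLXService changes below.

```swift
import MLX
import MLXLMCommon

// Sketch only: `kvCaches` is assumed to come from the model, e.g.
// context.model.newCache(parameters: parameters).
func promptCacheExample(kvCaches: [KVCache]) {
    let promptCache = PromptCache(cache: kvCaches)

    // First turn: nothing is cached yet, so the whole prompt comes back as the suffix.
    let turn1: MLXArray = [1, 2, 3, 4]
    let suffix1 = promptCache.getUncachedSuffix(prompt: turn1)  // [1, 2, 3, 4]

    // Second turn: the new prompt repeats the first one plus two new tokens,
    // so only the new tail needs to be run through the model before generation.
    let turn2: MLXArray = [1, 2, 3, 4, 5, 6]
    let suffix2 = promptCache.getUncachedSuffix(prompt: turn2)  // [5, 6]

    _ = (suffix1, suffix2)
}
```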
MLXService.swift
@@ -19,6 +19,7 @@ class MLXService {
     /// Includes both language models (LLM) and vision-language models (VLM).
     static let availableModels: [LMModel] = [
         LMModel(name: "llama3.2:1b", configuration: LLMRegistry.llama3_2_1B_4bit, type: .llm),
+        LMModel(name: "llama3.2:3b", configuration: LLMRegistry.llama3_2_3B_4bit, type: .llm),
         LMModel(name: "qwen2.5:1.5b", configuration: LLMRegistry.qwen2_5_1_5b, type: .llm),
         LMModel(name: "smolLM:135m", configuration: LLMRegistry.smolLM_135M_4bit, type: .llm),
         LMModel(name: "qwen3:0.6b", configuration: LLMRegistry.qwen3_0_6b_4bit, type: .llm),

@@ -34,6 +35,9 @@ class MLXService {
     /// Cache to store loaded model containers to avoid reloading.
     private let modelCache = NSCache<NSString, ModelContainer>()

+    /// Stores a prompt cache for each loaded model
+    private let promptCache = NSCache<NSString, PromptCache>()
+
     /// Tracks the current model download progress.
     /// Access this property to monitor model download status.
     @MainActor

@@ -51,6 +55,7 @@ class MLXService {
         if let container = modelCache.object(forKey: model.name as NSString) {
             return container
         } else {
+            print("Model not loaded \(model.name), loading model...")
             // Select appropriate factory based on model type
             let factory: ModelFactory =
                 switch model.type {

@@ -69,6 +74,9 @@ class MLXService {
                 }
             }

+            // Clear out the promptCache
+            promptCache.removeObject(forKey: model.name as NSString)
+
             // Cache the loaded model for future use
             modelCache.setObject(container, forKey: model.name as NSString)

@@ -111,12 +119,51 @@ class MLXService {

         // Generate response using the model
         return try await modelContainer.perform { (context: ModelContext) in
-            let lmInput = try await context.processor.prepare(input: userInput)
-            // Set temperature for response randomness (0.7 provides good balance)
+
+            let fullPrompt = try await context.processor.prepare(input: userInput)
+
             let parameters = GenerateParameters(temperature: 0.7)
+
+            // TODO: Prompt cache access isn't isolated
+            // Get the prompt cache and adjust new prompt to remove
+            // prefix already in cache, trim cache if cache is
+            // inconsistent with new prompt.
+            let (cache, lmInput) = getPromptCache(
+                fullPrompt: fullPrompt, parameters: parameters, context: context,
+                modelName: model.name)
+
+            // TODO: The generated tokens should be added to the prompt cache but not possible with AsyncStream
             return try MLXLMCommon.generate(
-                input: lmInput, parameters: parameters, context: context)
+                input: lmInput, parameters: parameters, context: context, cache: cache.cache)
         }
     }

+    func getPromptCache(
+        fullPrompt: LMInput, parameters: GenerateParameters, context: ModelContext,
+        modelName: String
+    ) -> (PromptCache, LMInput) {
+        let cache: PromptCache
+        if let existingCache = promptCache.object(forKey: modelName as NSString) {
+            cache = existingCache
+        } else {
+            // Create cache if it doesn't exist yet
+            cache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(cache, forKey: modelName as NSString)
+        }
+
+        let lmInput: LMInput
+
+        /// Remove prefix from prompt that is already in cache
+        if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
+            lmInput = LMInput(text: LMInput.Text(tokens: suffix))
Review comment:
I thought that the KVCache and the prompt should match in size -- that is, I thought the prompt should not have the pieces that are already in the KVCache trimmed off. Hrm, perhaps I am confused; here is the LLM prefill code:

    while y.tokens.size > prefillStepSize {
        let input = y[.newAxis, ..<prefillStepSize]
        let result = self(input, cache: cache.isEmpty ? nil : cache, state: state)
        eval(cache)
        y = y[prefillStepSize...]
    }

and I think it matches what you are doing here.

Reply:
Looking at that code, it looks to me like it is just passing the entire prompt (which is whatever was passed to it) through the model, using the existing cache. To be able to do trimming, the cache or the model would need to have a record of all of the tokens in the cache up to that point, and I'm not sure it does? But if it did, the trimming logic could be moved there.

One issue we have at the moment with PromptCache is that it is responsible for recording the tokens represented by the cache. However, because AsyncStream doesn't return tokens, we have no way of updating the token list for the cache with the generated response. As a result we always trim the previous response from the new prompt, because the cache doesn't know it is already there. Either AsyncStream should return the tokens (which may not be a bad idea anyway), or the cache should move closer to where the tokens are generated so they can be added there. However, KVCache doesn't have a record of the tokens (I don't think?), which is why we need PromptCache with its tokens MLXArray; so if the cache were to be fully managed within TokenIterator we would need to pass it the full PromptCache.

Reply:
I understand the chat (turn taking) style use of KVCache a little bit more. I don't think we need to observe tokens directly -- KVCache already represents that. If we want to trim the cache we have a couple of options:

If we put the tokens in the AsyncStream, that requires the holder of that stream to use the streaming detokenizer (not Sendable) and complicates things. Now you may be right -- there may be certain cases where we do need the tokens, but I wonder if that should be handled synchronously inside the TokenIterator?
+        } else {
+            // If suffix is nil, the cache is inconsistent with the new prompt
+            // and the cache doesn't support trimming so create a new one here.
+            let newCache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(newCache, forKey: modelName as NSString)
+            lmInput = fullPrompt
+        }
+
+        return (cache, lmInput)
+    }
 }
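As a side note on this design, keying the prompt cache by model name in an NSCache means resetting a conversation is just a matter of dropping that entry. A hypothetical helper, not part of this PR and assumed to live in the same file as MLXService so it can reach the private property, might look like:

```swift
extension MLXService {
    /// Hypothetical helper (not in this PR): forget the stored prompt/KV cache for a
    /// model so the next call to generate starts from an empty cache.
    func clearPromptCache(for model: LMModel) {
        promptCache.removeObject(forKey: model.name as NSString)
    }
}
```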
MLXLMCommon (generate functions)
@@ -529,14 +529,15 @@ public func generate(
 /// - input: prepared language model input
 /// - parameters: parameters controlling the token generation
 /// - context: model context (model and tokenizer)
+/// - cache: KV cache from previous output
 /// - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
 /// - Returns: the generated output
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext,
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
Review comment: I think some of this is already in place now.
     didGenerate: ([Int]) -> GenerateDisposition
 ) throws -> GenerateResult {
     let iterator = try TokenIterator(
-        input: input, model: context.model, parameters: parameters)
+        input: input, model: context.model, cache: cache, parameters: parameters)
     return generate(
         input: input, context: context, iterator: iterator, didGenerate: didGenerate)
 }

@@ -626,14 +627,15 @@ public func generate(
 /// - input: prepared language model input
 /// - parameters: parameters controlling the token generation
 /// - context: model context (model and tokenizer)
+/// - cache: KV cache from previous output
 /// - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
 /// - Returns: Information about the generation
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext,
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
     didGenerate: (Int) -> GenerateDisposition
 ) throws -> GenerateCompletionInfo {
     let iterator = try TokenIterator(
-        input: input, model: context.model, parameters: parameters)
+        input: input, model: context.model, cache: cache, parameters: parameters)
     return generate(
         input: input, context: context, iterator: iterator, didGenerate: didGenerate)
 }

@@ -702,6 +704,7 @@ public func generate(
 /// - input: The input for the language model.
 /// - parameters: The configuration options for token generation.
 /// - context: The model context, including the model itself and associated tokenizer.
+/// - cache: KV cache from previous output
 /// - Returns: An `AsyncStream` that emits `Generation` values, including generated tokens (`.token`)
 ///   and completion information (`.info`).
 /// - Throws: An error if the `TokenIterator` initialization fails due to invalid input or model configuration.

@@ -729,10 +732,10 @@ public func generate(
 /// }
 /// ```
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil
 ) throws -> AsyncStream<Generation> {
     let iterator = try TokenIterator(
-        input: input, model: context.model, parameters: parameters)
+        input: input, model: context.model, cache: cache, parameters: parameters)
     return generate(
         input: input, context: context, iterator: iterator)
 }
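For reference, a hedged sketch of calling one of the extended overloads with an existing cache; the `context`, `lmInput`, and `promptCache` values are assumed to be set up as in the MLXService changes above, and passing `nil` (the default) keeps the previous behavior.

```swift
let parameters = GenerateParameters(temperature: 0.7)

// Reuse the KV caches from a previous turn via the new `cache` parameter.
let result = try MLXLMCommon.generate(
    input: lmInput, parameters: parameters, context: context,
    cache: promptCache.cache
) { tokens in
    // Keep generating until the iterator stops on its own (e.g. EOS).
    return .more
}
print(result.output)
```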
Review comment:
I think this part is key -- I think it will need a lock and a method like:

Reply:
I agree. I've forgotten some of the details but I had roadblocks with each approach. I think there was an issue with ModelContainer.perform() being asynchronous and trying to wrap that with something like withCache. I might have to leave it to someone with more expertise in Swift concurrency.
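For illustration only, a minimal sketch of the kind of lock-guarded accessor being discussed. The `PromptCacheStore` and `withPromptCache` names and shape are assumptions, not code from this PR, and a synchronous lock like this does not by itself resolve the async ModelContainer.perform() issue mentioned above.

```swift
import Foundation

/// Hypothetical sketch: serialize access to per-model prompt caches so that
/// concurrent generate calls cannot mutate the same PromptCache at once.
final class PromptCacheStore: @unchecked Sendable {
    private let lock = NSLock()
    private var caches: [String: PromptCache] = [:]

    /// Runs `body` with the cache for `modelName` while holding the lock,
    /// creating the cache on first use via `makeCache`.
    func withPromptCache<R>(
        for modelName: String,
        makeCache: () -> PromptCache,
        _ body: (PromptCache) throws -> R
    ) rethrows -> R {
        lock.lock()
        defer { lock.unlock() }
        let cache: PromptCache
        if let existing = caches[modelName] {
            cache = existing
        } else {
            cache = makeCache()
            caches[modelName] = cache
        }
        return try body(cache)
    }
}
```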