Commit 0400695

Added actor-based PromptCache and implemented in MLXChatExample
1 parent 40bc2d2 commit 0400695

File tree

3 files changed: +131 -53 lines changed


Applications/MLXChatExample/Services/MLXService.swift

Lines changed: 31 additions & 52 deletions
@@ -33,10 +33,10 @@ class MLXService {
     ]
 
     /// Cache to store loaded model containers to avoid reloading.
-    private let modelCache = NSCache<NSString, ModelContainer>()
+    private var modelCache: [String : ModelContainer] = [:]
 
-    /// Tracks the ID of the last model used to detect changes for cache invalidation.
-    private var lastUsedModelId: String?
+    /// Stores a prompt cache for each loaded model
+    private var promptCache: [String : PromptCache] = [:]
 
     /// Tracks the current model download progress.
     /// Access this property to monitor model download status.
@@ -52,7 +52,7 @@ class MLXService {
         MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
 
         // Return cached model if available to avoid reloading
-        if let container = modelCache.object(forKey: model.name as NSString) {
+        if let container = modelCache[model.name] {
             return container
         } else {
             // Select appropriate factory based on model type
@@ -74,8 +74,7 @@ class MLXService {
             }
 
             // Cache the loaded model for future use
-            modelCache.setObject(container, forKey: model.name as NSString)
-
+            modelCache[model.name] = container
             return container
         }
     }
@@ -90,14 +89,6 @@ class MLXService {
         // Load or retrieve model from cache
         let modelContainer = try await load(model: model)
 
-        // Check if the model has changed since last generation
-        if lastUsedModelId != model.name {
-            // Clear the cache if the model is different
-            await modelContainer.clearCache()
-            print("[MLXService] Model changed, cleared KV Cache.")
-            lastUsedModelId = model.name
-        }
-
         // Map app-specific Message type to Chat.Message for model input
         let chat = messages.map { message in
             let role: Chat.Message.Role =
@@ -123,48 +114,36 @@ class MLXService {
 
         // Generate response using the model
         return try await modelContainer.perform { (context: ModelContext) in
-            // --- Prompt Caching Logic ---
-            // Only prefill if there are more than just the initial system message and the current turn
-            // (user + empty assistant = 2). Assumes first message is system.
-            if messages.count > 3 {
-                // Prepare history: all messages except the last (empty assistant) one.
-                // The `processor.prepare` below handles the *full* input including the latest user message.
-                let historyMessages = Array(chat.dropLast()) // Drop the empty assistant message
-                let historyUserInput = UserInput(chat: historyMessages)
-
-                // Try to get history tokens. Need tokenizer from context.
-                do {
-                    // Attempt to use the processor first, as it might handle VLM details.
-                    // Note: This runs prepare twice (once for history, once for full), which is suboptimal.
-                    // A better approach might involve direct tokenizer access or a dedicated history tokenization method.
-                    let historyLmInput = try await context.processor.prepare(input: historyUserInput)
-                    let historyTokens = historyLmInput.text.tokens.asArray(Int.self)
-
-                    // Check if current cache offset matches history length
-                    let currentCacheOffset = context.kvCache?.first?.offset ?? 0 // Assuming single cache for now
-
-                    if currentCacheOffset != historyTokens.count {
-                        print("[MLXService] Prefilling cache for \(historyTokens.count) history tokens. Current cache offset: \(currentCacheOffset)")
-                        await modelContainer.prefill(promptTokens: historyTokens)
-                    } else {
-                        print("[MLXService] Cache already matches history length (\(currentCacheOffset) tokens). Skipping prefill.")
-                    }
-                } catch {
-                    // Fallback or error handling if history tokenization fails
-                    print("[MLXService] Warning: Could not prepare history tokens for prefill. Error: \(error). Proceeding without prefill.")
-                    // Ensure cache is clear if we couldn't reliably check/prefill
-                    await modelContainer.clearCache()
-                }
-            }
-            // --- End Caching Logic ---
 
-            // Prepare the *full* input (including the latest user message)
-            let lmInput = try await context.processor.prepare(input: userInput)
-            // Set temperature for response randomness (0.7 provides good balance)
+            let fullPrompt = try await context.processor.prepare(input: userInput)
+
             let parameters = GenerateParameters(temperature: 0.7)
 
+            let cache: PromptCache
+            if let existingCache = self.promptCache[model.name] {
+                cache = existingCache
+            } else {
+                // Create cache if it doesn't exist yet
+                cache = PromptCache(cache: context.model.newCache(parameters: parameters))
+                promptCache[model.name] = cache
+            }
+
+            let lmInput: LMInput
+
+            /// Remove prefix from prompt that is already in cache
+            if let suffix = await cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
+                lmInput = LMInput(text: LMInput.Text(tokens: suffix))
+            } else {
+                // If suffix is nil, the cache is inconsistent with the new prompt
+                // and the cache doesn't support trimming so create a new one here.
+                self.promptCache[model.name] = PromptCache(cache: context.model.newCache(parameters: parameters))
+                lmInput = fullPrompt
+            }
+
+            // TODO: cache.perform ...
+            // TODO: The generated tokens should be added to the prompt cache but not possible with AsyncStream
             return try MLXLMCommon.generate(
-                input: lmInput, parameters: parameters, context: context)
+                input: lmInput, parameters: parameters, context: context, cache: await cache.cache)
         }
     }
 }

Libraries/MLXLMCommon/ModelFactory.swift

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ public struct ModelContext {
     public var model: any LanguageModel
     public var processor: any UserInputProcessor
     public var tokenizer: Tokenizer
-    public var kvCache: [KVCache]? = nil
 
     public init(
         configuration: ModelConfiguration, model: any LanguageModel,
PromptCache.swift (new file)

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+//
+// PromptCache.swift
+// mlx-swift-examples
+//
+// Created by Jolon Faichney on 3/5/2025.
+//
+
+import MLX
+
+public actor PromptCache {
+    public let cache: [KVCache]
+    public var tokens: MLXArray
+
+    public init(cache: [KVCache]) {
+        print("[PromptCache.init]")
+        self.cache = cache
+        self.tokens = []
+    }
+
+    /// Returns the suffix of the prompt not already in cache, so that only
+    /// the new part is processed. The tokens of the cache are adjusted here
+    /// to reflect the new full prompt (i.e. the suffix tokens are added to the
+    /// cache tokens array), assuming that the prompt suffix will
+    /// be processed after the call to this function.
+    ///
+    /// Trims cache if necessary if part of the cache doesn't match the new
+    /// prompt. If the model doesn't support trimming and the cache needs to be
+    /// trimmed, will return nil for the caller to create a new cache.
+    ///
+    /// - Returns:
+    ///     - If entirety of cache is in the new prompt:
+    ///         - Return suffix of new prompt, less what is in the cache
+    ///     - If only a portion of the cache is in the new prompt:
+    ///         - Attempt to trim the cache to the common prefix
+    ///         - Return suffix of prompt not in cache
+    ///     - If the cache is not trimmable return nil for the caller
+    ///       to create a new cache.
+    public func getUncachedSuffix(prompt: MLXArray) async -> MLXArray? {
+
+        print("[getUncachedSuffix] self.tokens.size = \(self.tokens.size)")
+
+        print("cache[\(self.tokens.size)]: \(self.tokens)")
+        print("prompt[\(prompt.size)]: \(prompt)")
+
+        let comPrefixLength = commonPrefixLength(newPromptTokens: prompt)
+        print("[getUncachedSuffix] comPrefixLength: \(comPrefixLength)")
+
+        if comPrefixLength == self.tokens.size {
+            let optPrompt = prompt[comPrefixLength..<prompt.size]
+            print("Concating...")
+            self.tokens = concatenated([self.tokens, optPrompt], axis: 0)
+            return optPrompt
+        } else if (comPrefixLength < self.tokens.size) {
+            if isTrimmable() {
+                print("trimming: \(self.tokens.size - comPrefixLength)")
+                let trimmedLen = self.trim(self.tokens.size - comPrefixLength)
+                print("trimmed: \(trimmedLen)")
+                if trimmedLen != self.tokens.size - comPrefixLength {
+                    print("Warning: request trimmed amount and actual trimmed amount are different")
+                }
+                self.tokens = self.tokens[0..<comPrefixLength]
+                let optPrompt = prompt[comPrefixLength..<prompt.size]
+                self.tokens = concatenated([self.tokens, optPrompt], axis: 0)
+                return optPrompt
+            } else {
+                // Caller must create a new cache
+                return nil
+            }
+        }
+
+        return nil
+    }
+
+    public func isTrimmable() -> Bool {
+        return cache.allSatisfy { $0.isTrimmable() }
+    }
+
+    public func trim(_ n: Int) -> Int {
+        if !self.isTrimmable() {
+            return 0
+        }
+        return cache.map { $0.trim(n: n) }.max() ?? 0
+    }
+
+    public func commonPrefixLength(newPromptTokens: MLXArray) -> Int {
+        return _commonPrefixLength(self.tokens, newPromptTokens)
+    }
+
+    // TODO: Add tests
+    public func _commonPrefixLength(_ array1: MLXArray, _ array2: MLXArray) -> Int {
+        print("Calculating common prefix: array1[\(array1.size)] array2[\(array2.size)]")
+        let minLength = min(array1.size, array2.size)
+        for i in 0..<minLength {
+            if all(array1[i] .!= array2[i]).item(Bool.self) {
+                return i
+            }
+        }
+        return minLength
+    }
+}
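
Not part of the commit — a minimal usage sketch of the new actor across two turns, assuming it is called from an async context in the same target. The token values and the promptCacheWalkthrough name are placeholders, and the empty [KVCache] list only keeps the example self-contained; in MLXService the caches come from context.model.newCache(parameters:) and the tokens from processor.prepare(...).

import MLX

// Illustrative only: exercises PromptCache with two prompts where the second
// extends the first, so only the new suffix has to be evaluated.
func promptCacheWalkthrough() async {
    // An empty cache list keeps the sketch self-contained; a real caller
    // passes the model's freshly created KV caches.
    let cache = PromptCache(cache: [])

    // First turn: nothing is cached yet, so the whole prompt comes back.
    let firstPrompt = MLXArray([1, 2, 3, 4] as [Int32])
    if let suffix = await cache.getUncachedSuffix(prompt: firstPrompt) {
        print("first turn, tokens to evaluate: \(suffix.size)")   // 4
    }

    // Second turn: the prompt repeats the first one and appends two tokens,
    // so only those two tokens are returned.
    let secondPrompt = MLXArray([1, 2, 3, 4, 5, 6] as [Int32])
    if let suffix = await cache.getUncachedSuffix(prompt: secondPrompt) {
        print("second turn, tokens to evaluate: \(suffix.size)")  // 2
    } else {
        // nil means the cache diverged and could not be trimmed;
        // the caller creates a new PromptCache, as MLXService does above.
        print("cache not reusable, recreate it")
    }
}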
