Commit cb7d9d4

feat: Implement prompt caching in MLXChatExample
Made prefill async, implemented caching logic in MLXService, and fixed related warnings.
1 parent 1275acc commit cb7d9d4
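In outline, the caching logic added to MLXService boils down to two checks: clear the KV cache when the selected model changes, and prefill it with the tokenized chat history only when the current cache offset does not already cover that history. Below is a condensed sketch for orientation only, not code from the commit; it reuses only the clearCache(), prefill(promptTokens:), and kvCache offset calls that appear in the diff, and assumes the caller has already resolved the container, context, and history tokens the way the real generate method does.

import MLXLMCommon

// Condensed sketch of the caching flow added in MLXService.swift (see the diff below).
// Not part of the commit; the container, context, and tokenized history are assumed
// to have been prepared by the caller.
func reuseOrPrefillCache(
    modelContainer: ModelContainer,
    context: ModelContext,
    historyTokens: [Int],
    modelName: String,
    lastUsedModelId: inout String?
) async {
    // Switching models invalidates any previously built KV cache.
    if lastUsedModelId != modelName {
        await modelContainer.clearCache()
        lastUsedModelId = modelName
    }

    // Prefill only when the cache does not already cover the history tokens.
    let cacheOffset = context.kvCache?.first?.offset ?? 0
    if cacheOffset != historyTokens.count {
        await modelContainer.prefill(promptTokens: historyTokens)
    }
}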

File tree (2 files changed, +50 -2 lines):

Applications/MLXChatExample/Services/MLXService.swift
Libraries/MLXLMCommon/ModelContainer.swift

Applications/MLXChatExample/Services/MLXService.swift

Lines changed: 48 additions & 0 deletions
@@ -10,6 +10,7 @@ import MLX
 import MLXLLM
 import MLXLMCommon
 import MLXVLM
+import Tokenizers // Needed for applyChatTemplate

 /// A service class that manages machine learning models for text and vision-language tasks.
 /// This class handles model loading, caching, and text generation using various LLM and VLM models.

@@ -34,6 +35,9 @@ class MLXService {
     /// Cache to store loaded model containers to avoid reloading.
     private let modelCache = NSCache<NSString, ModelContainer>()

+    /// Tracks the ID of the last model used to detect changes for cache invalidation.
+    private var lastUsedModelId: String?
+
     /// Tracks the current model download progress.
     /// Access this property to monitor model download status.
     @MainActor

@@ -86,6 +90,14 @@ class MLXService {
         // Load or retrieve model from cache
         let modelContainer = try await load(model: model)

+        // Check if the model has changed since last generation
+        if lastUsedModelId != model.name {
+            // Clear the cache if the model is different
+            await modelContainer.clearCache()
+            print("[MLXService] Model changed, cleared KV Cache.")
+            lastUsedModelId = model.name
+        }
+
         // Map app-specific Message type to Chat.Message for model input
         let chat = messages.map { message in
             let role: Chat.Message.Role =

@@ -111,6 +123,42 @@ class MLXService {

         // Generate response using the model
         return try await modelContainer.perform { (context: ModelContext) in
+            // --- Prompt Caching Logic ---
+            // Only prefill if there are more than just the initial system message and the current turn
+            // (user + empty assistant = 2). Assumes first message is system.
+            if messages.count > 3 {
+                // Prepare history: all messages except the last (empty assistant) one.
+                // The `processor.prepare` below handles the *full* input including the latest user message.
+                let historyMessages = Array(chat.dropLast()) // Drop the empty assistant message
+                let historyUserInput = UserInput(chat: historyMessages)
+
+                // Try to get history tokens. Need tokenizer from context.
+                do {
+                    // Attempt to use the processor first, as it might handle VLM details.
+                    // Note: This runs prepare twice (once for history, once for full), which is suboptimal.
+                    // A better approach might involve direct tokenizer access or a dedicated history tokenization method.
+                    let historyLmInput = try await context.processor.prepare(input: historyUserInput)
+                    let historyTokens = historyLmInput.text.tokens.asArray(Int.self)
+
+                    // Check if current cache offset matches history length
+                    let currentCacheOffset = context.kvCache?.first?.offset ?? 0 // Assuming single cache for now
+
+                    if currentCacheOffset != historyTokens.count {
+                        print("[MLXService] Prefilling cache for \(historyTokens.count) history tokens. Current cache offset: \(currentCacheOffset)")
+                        await modelContainer.prefill(promptTokens: historyTokens)
+                    } else {
+                        print("[MLXService] Cache already matches history length (\(currentCacheOffset) tokens). Skipping prefill.")
+                    }
+                } catch {
+                    // Fallback or error handling if history tokenization fails
+                    print("[MLXService] Warning: Could not prepare history tokens for prefill. Error: \(error). Proceeding without prefill.")
+                    // Ensure cache is clear if we couldn't reliably check/prefill
+                    await modelContainer.clearCache()
+                }
+            }
+            // --- End Caching Logic ---
+
+            // Prepare the *full* input (including the latest user message)
             let lmInput = try await context.processor.prepare(input: userInput)
             // Set temperature for response randomness (0.7 provides good balance)
             let parameters = GenerateParameters(temperature: 0.7)
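As the comment in the last hunk notes, tokenizing the history through a second processor.prepare pass is suboptimal. One possible shape of the "direct tokenizer access" alternative it mentions, sketched on the assumption that the tokenizer reachable from the model context is a swift-transformers Tokenizer whose applyChatTemplate(messages:) returns prompt token IDs; this text-only path would not cover VLM image or video history.

import Tokenizers

// Hypothetical alternative to the double processor.prepare call above: tokenize
// the chat history once via the tokenizer's chat template. Assumes a
// swift-transformers Tokenizer exposing applyChatTemplate(messages:) -> [Int];
// history is passed as plain role/content pairs, so VLM attachments are out of scope.
func tokenizeHistory(
    _ history: [(role: String, content: String)],
    with tokenizer: Tokenizer
) throws -> [Int] {
    let messages = history.map { ["role": $0.role, "content": $0.content] }
    return try tokenizer.applyChatTemplate(messages: messages)
}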

Libraries/MLXLMCommon/ModelContainer.swift

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ public actor ModelContainer {
     /// - Parameters:
     ///   - promptTokens: The token IDs to prefill the cache with.
     ///   - chunkSize: The number of tokens to process in each model evaluation step. Defaults to 512.
-    public func prefill(promptTokens: [Int], chunkSize: Int = 512) {
+    public func prefill(promptTokens: [Int], chunkSize: Int = 512) async {
         // Ensure we have tokens to process
         guard !promptTokens.isEmpty else {
             // If the prompt is empty, ensure the cache is cleared

@@ -106,7 +106,7 @@ public actor ModelContainer {
         let newCache = context.model.newCache(parameters: nil)

         // Convert tokens to MLXArray
-        var tokensToProcess = MLXArray(promptTokens)
+        let tokensToProcess = MLXArray(promptTokens)

         // Process tokens in chunks
         var currentOffset = 0
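Since ModelContainer is an actor and prefill is now async, call sites outside the actor have to await it. A minimal usage sketch, not part of the commit; the container and token values are placeholders supplied by the caller.

import MLXLMCommon

// Minimal usage sketch for the now-async prefill. `container` and `historyTokens`
// are hypothetical values; the default chunk size of 512 is spelled out for clarity.
func warmCache(_ container: ModelContainer, with historyTokens: [Int]) async {
    // Suspends until the actor has pushed the history tokens through the model
    // in chunks, rebuilding its KV cache.
    await container.prefill(promptTokens: historyTokens, chunkSize: 512)
}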
