@@ -10,6 +10,7 @@ import MLX
 import MLXLLM
 import MLXLMCommon
 import MLXVLM
+import Tokenizers  // Needed for applyChatTemplate
 
 /// A service class that manages machine learning models for text and vision-language tasks.
 /// This class handles model loading, caching, and text generation using various LLM and VLM models.
@@ -34,6 +35,9 @@ class MLXService {
     /// Cache to store loaded model containers to avoid reloading.
     private let modelCache = NSCache<NSString, ModelContainer>()
 
+    /// Tracks the ID of the last model used to detect changes for cache invalidation.
+    private var lastUsedModelId: String?
+
     /// Tracks the current model download progress.
     /// Access this property to monitor model download status.
     @MainActor
@@ -86,6 +90,14 @@ class MLXService {
         // Load or retrieve model from cache
         let modelContainer = try await load(model: model)
 
+        // Check if the model has changed since last generation
+        if lastUsedModelId != model.name {
+            // Clear the cache if the model is different
+            await modelContainer.clearCache()
+            print("[MLXService] Model changed, cleared KV Cache.")
+            lastUsedModelId = model.name
+        }
+
         // Map app-specific Message type to Chat.Message for model input
         let chat = messages.map { message in
             let role: Chat.Message.Role =
@@ -111,6 +123,42 @@ class MLXService {
 
         // Generate response using the model
         return try await modelContainer.perform { (context: ModelContext) in
+            // --- Prompt Caching Logic ---
+            // Only prefill if there are more than just the initial system message and the current turn
+            // (user + empty assistant = 2). Assumes first message is system.
+            if messages.count > 3 {
+                // Prepare history: all messages except the last (empty assistant) one.
+                // The `processor.prepare` below handles the *full* input including the latest user message.
+                let historyMessages = Array(chat.dropLast())  // Drop the empty assistant message
+                let historyUserInput = UserInput(chat: historyMessages)
+
+                // Try to get history tokens. Need tokenizer from context.
+                do {
+                    // Attempt to use the processor first, as it might handle VLM details.
+                    // Note: This runs prepare twice (once for history, once for full), which is suboptimal.
+                    // A better approach might involve direct tokenizer access or a dedicated history tokenization method.
+                    let historyLmInput = try await context.processor.prepare(input: historyUserInput)
+                    let historyTokens = historyLmInput.text.tokens.asArray(Int.self)
+
+                    // Check if current cache offset matches history length
+                    let currentCacheOffset = context.kvCache?.first?.offset ?? 0  // Assuming single cache for now
+
+                    if currentCacheOffset != historyTokens.count {
+                        print("[MLXService] Prefilling cache for \(historyTokens.count) history tokens. Current cache offset: \(currentCacheOffset)")
+                        await modelContainer.prefill(promptTokens: historyTokens)
+                    } else {
+                        print("[MLXService] Cache already matches history length (\(currentCacheOffset) tokens). Skipping prefill.")
+                    }
+                } catch {
+                    // Fallback or error handling if history tokenization fails
+                    print("[MLXService] Warning: Could not prepare history tokens for prefill. Error: \(error). Proceeding without prefill.")
+                    // Ensure cache is clear if we couldn't reliably check/prefill
+                    await modelContainer.clearCache()
+                }
+            }
+            // --- End Caching Logic ---
+
+            // Prepare the *full* input (including the latest user message)
             let lmInput = try await context.processor.prepare(input: userInput)
             // Set temperature for response randomness (0.7 provides good balance)
             let parameters = GenerateParameters(temperature: 0.7)
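
Note on the double `prepare` pass flagged in the comments above: one possible way to tokenize the history without a second `processor.prepare` call is to render the chat template directly with the tokenizer, which appears to be the reason for the new `import Tokenizers`. The sketch below is illustrative only, not part of this change: it assumes `context.tokenizer` exposes swift-transformers' `applyChatTemplate(messages:)`, that the stringified role names match what the template expects, and that the history is text-only (image or video content would still need the processor). The helper name and dictionary keys are hypothetical.

// Hypothetical helper: tokenize the chat history in one pass via the chat template,
// instead of running processor.prepare twice. Assumes text-only history.
private func historyTokens(for history: [Chat.Message], tokenizer: Tokenizer) throws -> [Int] {
    // Convert Chat.Message values into the role/content dictionaries applyChatTemplate expects.
    // String(describing:) is assumed to yield "system"/"user"/"assistant" for the role enum.
    let templatedMessages = history.map { message in
        ["role": String(describing: message.role), "content": message.content]
    }
    // applyChatTemplate renders the template and returns the prompt as token IDs.
    return try tokenizer.applyChatTemplate(messages: templatedMessages)
}

With such a helper, the `messages.count > 3` branch could compare the history token count against the KV cache offset and call `prefill` without preparing the history input a second time.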