@@ -33,10 +33,10 @@ class MLXService {
     ]
 
     /// Cache to store loaded model containers to avoid reloading.
-    private let modelCache = NSCache<NSString, ModelContainer>()
+    private var modelCache: [String: ModelContainer] = [:]
 
-    /// Tracks the ID of the last model used to detect changes for cache invalidation.
-    private var lastUsedModelId: String?
+    /// Stores a prompt cache for each loaded model
+    private var promptCache: [String: PromptCache] = [:]
 
     /// Tracks the current model download progress.
     /// Access this property to monitor model download status.
@@ -52,7 +52,7 @@ class MLXService {
         MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
 
         // Return cached model if available to avoid reloading
-        if let container = modelCache.object(forKey: model.name as NSString) {
+        if let container = modelCache[model.name] {
             return container
         } else {
             // Select appropriate factory based on model type
@@ -74,8 +74,7 @@ class MLXService {
             }
 
             // Cache the loaded model for future use
-            modelCache.setObject(container, forKey: model.name as NSString)
-
+            modelCache[model.name] = container
             return container
         }
     }
@@ -90,14 +89,6 @@ class MLXService {
         // Load or retrieve model from cache
         let modelContainer = try await load(model: model)
 
-        // Check if the model has changed since last generation
-        if lastUsedModelId != model.name {
-            // Clear the cache if the model is different
-            await modelContainer.clearCache()
-            print("[MLXService] Model changed, cleared KV Cache.")
-            lastUsedModelId = model.name
-        }
-
         // Map app-specific Message type to Chat.Message for model input
         let chat = messages.map { message in
             let role: Chat.Message.Role =
@@ -123,48 +114,36 @@ class MLXService {
 
         // Generate response using the model
         return try await modelContainer.perform { (context: ModelContext) in
-            // --- Prompt Caching Logic ---
-            // Only prefill if there are more than just the initial system message and the current turn
-            // (user + empty assistant = 2). Assumes first message is system.
-            if messages.count > 3 {
-                // Prepare history: all messages except the last (empty assistant) one.
-                // The `processor.prepare` below handles the *full* input including the latest user message.
-                let historyMessages = Array(chat.dropLast())  // Drop the empty assistant message
-                let historyUserInput = UserInput(chat: historyMessages)
-
-                // Try to get history tokens. Need tokenizer from context.
-                do {
-                    // Attempt to use the processor first, as it might handle VLM details.
-                    // Note: This runs prepare twice (once for history, once for full), which is suboptimal.
-                    // A better approach might involve direct tokenizer access or a dedicated history tokenization method.
-                    let historyLmInput = try await context.processor.prepare(input: historyUserInput)
-                    let historyTokens = historyLmInput.text.tokens.asArray(Int.self)
-
-                    // Check if current cache offset matches history length
-                    let currentCacheOffset = context.kvCache?.first?.offset ?? 0  // Assuming single cache for now
-
-                    if currentCacheOffset != historyTokens.count {
-                        print("[MLXService] Prefilling cache for \(historyTokens.count) history tokens. Current cache offset: \(currentCacheOffset)")
-                        await modelContainer.prefill(promptTokens: historyTokens)
-                    } else {
-                        print("[MLXService] Cache already matches history length (\(currentCacheOffset) tokens). Skipping prefill.")
-                    }
-                } catch {
-                    // Fallback or error handling if history tokenization fails
-                    print("[MLXService] Warning: Could not prepare history tokens for prefill. Error: \(error). Proceeding without prefill.")
-                    // Ensure cache is clear if we couldn't reliably check/prefill
-                    await modelContainer.clearCache()
-                }
-            }
-            // --- End Caching Logic ---
 
-            // Prepare the *full* input (including the latest user message)
-            let lmInput = try await context.processor.prepare(input: userInput)
-            // Set temperature for response randomness (0.7 provides good balance)
+            let fullPrompt = try await context.processor.prepare(input: userInput)
+
             let parameters = GenerateParameters(temperature: 0.7)
 
+            let cache: PromptCache
+            if let existingCache = self.promptCache[model.name] {
+                cache = existingCache
+            } else {
+                // Create cache if it doesn't exist yet
+                cache = PromptCache(cache: context.model.newCache(parameters: parameters))
+                promptCache[model.name] = cache
+            }
+
+            let lmInput: LMInput
+
+            // Remove prefix from prompt that is already in cache
+            if let suffix = await cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
+                lmInput = LMInput(text: LMInput.Text(tokens: suffix))
+            } else {
+                // If suffix is nil, the cache is inconsistent with the new prompt
+                // and the cache doesn't support trimming so create a new one here.
+                self.promptCache[model.name] = PromptCache(cache: context.model.newCache(parameters: parameters))
+                lmInput = fullPrompt
+            }
+
+            // TODO: cache.perform ...
+            // TODO: The generated tokens should be added to the prompt cache but not possible with AsyncStream
             return try MLXLMCommon.generate(
-                input: lmInput, parameters: parameters, context: context)
+                input: lmInput, parameters: parameters, context: context, cache: await cache.cache)
         }
     }
 }
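Note: the hunks above assume a PromptCache type, introduced elsewhere in the same change, that pairs the model's KV cache with the tokens already fed through it. The following is only an illustrative sketch of what such a helper could look like given how the diff uses it; the actor layout, the cachedTokens bookkeeping, and the body of getUncachedSuffix are assumptions, not the PR's actual implementation.

// Illustrative sketch only -- the real PromptCache in this change may differ.
import MLX
import MLXLMCommon

actor PromptCache {
    /// The model's key/value cache, reused across turns.
    let cache: [KVCache]

    /// Tokens that have already been fed through `cache` (assumed bookkeeping).
    private var cachedTokens: [Int] = []

    init(cache: [KVCache]) {
        self.cache = cache
    }

    /// Returns the part of `prompt` not yet represented in the cache, or nil
    /// when the cached prefix no longer matches the new prompt (the caller
    /// then discards this cache and creates a fresh one).
    func getUncachedSuffix(prompt: MLXArray) -> MLXArray? {
        let promptTokens = prompt.asArray(Int.self)
        let cachedCount = cachedTokens.count

        // The new prompt must strictly extend what is already cached.
        guard promptTokens.count > cachedCount,
            Array(promptTokens.prefix(cachedCount)) == cachedTokens
        else { return nil }

        // Record the full prompt as cached and return only the new tokens.
        cachedTokens = promptTokens
        return MLXArray(promptTokens[cachedCount...].map { Int32($0) })
    }
}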