@@ -32,10 +32,10 @@ class MLXService {
    ]

    /// Cache to store loaded model containers to avoid reloading.
-    private var modelCache: [String: ModelContainer] = [:]
+    private let modelCache = NSCache<NSString, ModelContainer>()

    /// Stores a prompt cache for each loaded model
-    private var promptCache: [String: PromptCache] = [:]
+    private let promptCache = NSCache<NSString, PromptCache>()

    /// Tracks the current model download progress.
    /// Access this property to monitor model download status.
@@ -51,9 +51,10 @@ class MLXService {
        MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

        // Return cached model if available to avoid reloading
-        if let container = modelCache[model.name] {
+        if let container = modelCache.object(forKey: model.name as NSString) {
            return container
        } else {
+            print("Model not loaded \(model.name), loading model...")
            // Select appropriate factory based on model type
            let factory: ModelFactory =
                switch model.type {
@@ -71,9 +72,13 @@ class MLXService {
                    self.modelDownloadProgress = progress
                }
            }
-
+
+            // Clear out the promptCache
+            promptCache.removeObject(forKey: model.name as NSString)
+
            // Cache the loaded model for future use
-            modelCache[model.name] = container
+            modelCache.setObject(container, forKey: model.name as NSString)
+
            return container
        }
    }
@@ -118,32 +123,41 @@ class MLXService {

            let parameters = GenerateParameters(temperature: 0.7)

-            // Get the prompt cache
-            let cache: PromptCache
-            if let existingCache = self.promptCache[model.name] {
-                cache = existingCache
-            } else {
-                // Create cache if it doesn't exist yet
-                cache = PromptCache(cache: context.model.newCache(parameters: parameters))
-                promptCache[model.name] = cache
-            }
-
-            let lmInput: LMInput
-
-            /// Remove prefix from prompt that is already in cache
-            if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
-                lmInput = LMInput(text: LMInput.Text(tokens: suffix))
-            } else {
-                // If suffix is nil, the cache is inconsistent with the new prompt
-                // and the cache doesn't support trimming so create a new one here.
-                self.promptCache[model.name] = PromptCache(cache: context.model.newCache(parameters: parameters))
-                lmInput = fullPrompt
-            }
+            // TODO: Prompt cache access isn't isolated
+            // Get the prompt cache and adjust new prompt to remove
+            // prefix already in cache, trim cache if cache is
+            // inconsistent with new prompt.
+            let (cache, lmInput) = getPromptCache(fullPrompt: fullPrompt, parameters: parameters, context: context, modelName: model.name)

-            // TODO: cache.perform ...
            // TODO: The generated tokens should be added to the prompt cache but not possible with AsyncStream
            return try MLXLMCommon.generate(
                input: lmInput, parameters: parameters, context: context, cache: cache.cache)
        }
    }
+
+    func getPromptCache(fullPrompt: LMInput, parameters: GenerateParameters, context: ModelContext, modelName: String) -> (PromptCache, LMInput) {
+        let cache: PromptCache
+        if let existingCache = promptCache.object(forKey: modelName as NSString) {
+            cache = existingCache
+        } else {
+            // Create cache if it doesn't exist yet
+            cache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(cache, forKey: modelName as NSString)
+        }
+
+        let lmInput: LMInput
+
+        /// Remove prefix from prompt that is already in cache
+        if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
+            lmInput = LMInput(text: LMInput.Text(tokens: suffix))
+        } else {
+            // If suffix is nil, the cache is inconsistent with the new prompt
+            // and the cache doesn't support trimming so create a new one here.
+            let newCache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(newCache, forKey: modelName as NSString)
+            lmInput = fullPrompt
+        }
+
+        return (cache, lmInput)
+    }
}
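
The dictionary-to-NSCache change above follows a standard Foundation pattern: NSCache only accepts reference types for keys and values, so the String model name is bridged to NSString on every lookup, individual reads and writes are thread-safe without extra locking, and entries may be evicted under memory pressure, so callers must be prepared to rebuild them. A compound check-then-create sequence like the one in getPromptCache is still not atomic, however. Below is a minimal, self-contained sketch of the lookup-or-create pattern, using a hypothetical Box class as a stand-in for ModelContainer or PromptCache; none of these names come from the commit itself.

import Foundation

// Hypothetical payload type; NSCache values must be reference types.
final class Box {
    let value: Int
    init(_ value: Int) { self.value = value }
}

// NSCache keys must also be reference types, hence NSString instead of String.
let cache = NSCache<NSString, Box>()

func cachedBox(for name: String) -> Box {
    // Return the cached object if it is still resident.
    if let existing = cache.object(forKey: name as NSString) {
        return existing
    }
    // Otherwise build it (expensive work would go here) and store it.
    // NSCache may evict this entry under memory pressure.
    let fresh = Box(42)
    cache.setObject(fresh, forKey: name as NSString)
    return fresh
}

print(cachedBox(for: "example-model").value)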