
Commit 6f48543

Tidy + modelCache and promptCache in MLXChatExample changed to NSCache.
1 parent 07867be · commit 6f48543

File tree

Applications/MLXChatExample/Models/PromptCache.swift
Applications/MLXChatExample/Services/MLXService.swift
Libraries/MLXLMCommon/Evaluate.swift

3 files changed: +59 -44 lines changed


Applications/MLXChatExample/Models/PromptCache.swift

Lines changed: 17 additions & 16 deletions
@@ -101,23 +101,24 @@ public class PromptCache: @unchecked Sendable {
     /// - newPromptTokens: Tokens to compare with cached tokens.
     /// - Returns: Length of the common prefix
     public func commonPrefixLength(newPromptTokens: MLXArray) -> Int {
-        return MLX_Studio.commonPrefixLength(self.tokens, newPromptTokens)
+        return commonPrefixLength(self.tokens, newPromptTokens)
     }
-}
-
-/// Finds the common prefix between ``MLXArray``s.
-/// - Parameters:
-///     - array1: First array
-///     - array2: Second array
-/// - Returns: Length of the common prefix
-public func commonPrefixLength(_ array1: MLXArray, _ array2: MLXArray) -> Int {
-    // TODO: Add test cases
-    print("Calculating common prefix: array1[\(array1.size)] array2[\(array2.size)]")
-    let minLength = min(array1.size, array2.size)
-    for i in 0..<minLength {
-        if all(array1[i] .!= array2[i]).item(Bool.self) {
-            return i
+
+    /// Finds the common prefix between ``MLXArray``s.
+    /// - Parameters:
+    ///     - array1: First array
+    ///     - array2: Second array
+    /// - Returns: Length of the common prefix
+    public func commonPrefixLength(_ array1: MLXArray, _ array2: MLXArray) -> Int {
+        // TODO: Add test cases
+        print("Calculating common prefix: array1[\(array1.size)] array2[\(array2.size)]")
+        let minLength = min(array1.size, array2.size)
+        for i in 0..<minLength {
+            if all(array1[i] .!= array2[i]).item(Bool.self) {
+                return i
+            }
         }
+        return minLength
     }
-    return minLength
+
 }
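For context, the helper that moves inside the class simply walks both token sequences until the first mismatch and returns that index. Below is a minimal sketch of the same logic on plain Swift integer arrays, with no MLX dependency; the names are illustrative only, while the real method compares `MLXArray` elements with `all(array1[i] .!= array2[i]).item(Bool.self)` as shown in the diff.

```swift
// Sketch: common-prefix length over plain Swift token arrays.
func commonPrefixLength(_ a: [Int], _ b: [Int]) -> Int {
    let minLength = min(a.count, b.count)
    for i in 0..<minLength where a[i] != b[i] {
        return i  // First index where the two token streams diverge.
    }
    return minLength  // One array is a prefix of the other.
}

// The first three tokens match, so only the remaining suffix would need to be
// evaluated against a KV cache that already holds the shared prefix.
let cached = [101, 7592, 2088, 102]
let incoming = [101, 7592, 2088, 999, 102]
print(commonPrefixLength(cached, incoming))  // prints 3
```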

Applications/MLXChatExample/Services/MLXService.swift

Lines changed: 41 additions & 27 deletions
@@ -32,10 +32,10 @@ class MLXService {
     ]

     /// Cache to store loaded model containers to avoid reloading.
-    private var modelCache: [String : ModelContainer] = [:]
+    private let modelCache = NSCache<NSString, ModelContainer>()

     /// Stores a prompt cache for each loaded model
-    private var promptCache: [String : PromptCache] = [:]
+    private let promptCache = NSCache<NSString, PromptCache>()

     /// Tracks the current model download progress.
     /// Access this property to monitor model download status.
@@ -51,9 +51,10 @@ class MLXService {
         MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

         // Return cached model if available to avoid reloading
-        if let container = modelCache[model.name] {
+        if let container = modelCache.object(forKey: model.name as NSString) {
             return container
         } else {
+            print("Model not loaded \(model.name), loading model...")
             // Select appropriate factory based on model type
             let factory: ModelFactory =
                 switch model.type {
@@ -71,9 +72,13 @@
                     self.modelDownloadProgress = progress
                 }
             }
-
+
+            // Clear out the promptCache
+            promptCache.removeObject(forKey: model.name as NSString)
+
             // Cache the loaded model for future use
-            modelCache[model.name] = container
+            modelCache.setObject(container, forKey: model.name as NSString)
+
             return container
         }
     }
@@ -118,32 +123,41 @@

             let parameters = GenerateParameters(temperature: 0.7)

-            // Get the prompt cache
-            let cache: PromptCache
-            if let existingCache = self.promptCache[model.name] {
-                cache = existingCache
-            } else {
-                // Create cache if it doesn't exist yet
-                cache = PromptCache(cache: context.model.newCache(parameters: parameters))
-                promptCache[model.name] = cache
-            }
-
-            let lmInput: LMInput
-
-            /// Remove prefix from prompt that is already in cache
-            if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
-                lmInput = LMInput(text: LMInput.Text(tokens: suffix))
-            } else {
-                // If suffix is nil, the cache is inconsistent with the new prompt
-                // and the cache doesn't support trimming so create a new one here.
-                self.promptCache[model.name] = PromptCache(cache: context.model.newCache(parameters: parameters))
-                lmInput = fullPrompt
-            }
+            // TODO: Prompt cache access isn't isolated
+            // Get the prompt cache and adjust new prompt to remove
+            // prefix already in cache, trim cache if cache is
+            // inconsistent with new prompt.
+            let (cache, lmInput) = getPromptCache(fullPrompt: fullPrompt, parameters: parameters, context: context, modelName: model.name)

-            // TODO: cache.perform ...
             // TODO: The generated tokens should be added to the prompt cache but not possible with AsyncStream
             return try MLXLMCommon.generate(
                 input: lmInput, parameters: parameters, context: context, cache: cache.cache)
         }
     }
+
+    func getPromptCache(fullPrompt: LMInput, parameters: GenerateParameters, context: ModelContext, modelName: String) -> (PromptCache, LMInput) {
+        let cache: PromptCache
+        if let existingCache = promptCache.object(forKey: modelName as NSString) {
+            cache = existingCache
+        } else {
+            // Create cache if it doesn't exist yet
+            cache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(cache, forKey: modelName as NSString)
+        }
+
+        let lmInput: LMInput
+
+        /// Remove prefix from prompt that is already in cache
+        if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
+            lmInput = LMInput(text: LMInput.Text(tokens: suffix))
+        } else {
+            // If suffix is nil, the cache is inconsistent with the new prompt
+            // and the cache doesn't support trimming so create a new one here.
+            let newCache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(newCache, forKey: modelName as NSString)
+            lmInput = fullPrompt
+        }
+
+        return (cache, lmInput)
+    }
 }
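Moving from `[String : ModelContainer]` and `[String : PromptCache]` dictionaries to `NSCache` changes the semantics slightly: `NSCache` stores reference types keyed by `NSString`, is safe to read and write from multiple threads, and may evict entries under memory pressure, so a miss must always fall back to reloading, which the `load` path above already does. Below is a standalone sketch of that lookup pattern, using a hypothetical stand-in container class rather than the real MLXLMCommon types.

```swift
import Foundation

// `StubModelContainer` is a hypothetical stand-in, not the real MLXLMCommon type;
// NSCache requires reference types for its values and NSString keys.
final class StubModelContainer {
    let name: String
    init(name: String) { self.name = name }
}

final class ModelStore {
    private let modelCache = NSCache<NSString, StubModelContainer>()

    func container(for name: String) -> StubModelContainer {
        let key = name as NSString
        if let cached = modelCache.object(forKey: key) {
            return cached  // Hit: reuse the already loaded container.
        }
        // Miss (never loaded, or evicted under memory pressure): load and store.
        let loaded = StubModelContainer(name: name)
        modelCache.setObject(loaded, forKey: key)
        return loaded
    }

    func invalidate(_ name: String) {
        modelCache.removeObject(forKey: name as NSString)
    }
}

let store = ModelStore()
let first = store.container(for: "example-model")
let second = store.container(for: "example-model")
print(first === second)  // true as long as the entry has not been evicted
```

Because eviction can happen at any time, callers should never assume a previously stored container or prompt cache is still present; the commit's own TODO also notes that prompt-cache access is not yet isolated.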

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 1 addition & 1 deletion
@@ -732,7 +732,7 @@ public func generate(
 ///     }
 /// ```
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]?
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil
 ) throws -> AsyncStream<Generation> {
     let iterator = try TokenIterator(
         input: input, model: context.model, cache: cache, parameters: parameters)
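Giving `cache` a default of `nil` keeps existing `generate` call sites source-compatible while allowing MLXChatExample to pass its prompt cache explicitly. A simplified sketch of that effect, using stub types rather than the real `LMInput`/`KVCache` API:

```swift
// Stub types standing in for the real signature; the point is only how the
// `= nil` default behaves at call sites.
struct StubKVCache {}

func generate(prompt: String, cache: [StubKVCache]? = nil) -> String {
    // With no cache supplied, the function behaves exactly as before.
    let layers = cache?.count ?? 0
    return "generated \"\(prompt)\" with \(layers) cached layer(s)"
}

print(generate(prompt: "hello"))                          // existing call sites still compile
print(generate(prompt: "hello", cache: [StubKVCache()]))  // new callers can pass a cache
```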
