
Commit 07867be

Moved PromptCache.swift to MLXChatExample and changed it to an @unchecked Sendable class.
1 parent e4b697c commit 07867be

File tree

4 files changed: +43 additions, -18 deletions


Libraries/MLXLMCommon/PromptCache.swift renamed to Applications/MLXChatExample/Models/PromptCache.swift

Lines changed: 37 additions & 14 deletions
@@ -6,10 +6,18 @@
 //

 import MLX
+import MLXLMCommon

-public actor PromptCache {
-    public let cache: [KVCache]
-    public var tokens: MLXArray
+/// Stores the KV Cache between calls to ``generate`` and maintains
+/// the token ids reflected in the cache.
+///
+/// ``PromptCache`` is ``@unchecked Sendable`` which allows it
+/// to be used within the ``ModelContainer`` context.
+///
+/// TODO: cache isolation
+public class PromptCache: @unchecked Sendable {
+    private(set) var cache: [KVCache]
+    private(set) var tokens: MLXArray

     public init(cache: [KVCache]) {
         print("[PromptCache.init]")
@@ -35,7 +43,7 @@ public actor PromptCache {
     /// - Return suffix of prompt not in cache
     /// - If the cache is not trimmable return nil for the caller
     /// to create a new cache.
-    public func getUncachedSuffix(prompt: MLXArray) async -> MLXArray? {
+    public func getUncachedSuffix(prompt: MLXArray) -> MLXArray? {

         print("[getUncachedSuffix] self.tokens.size = \(self.tokens.size)")

@@ -71,30 +79,45 @@ public actor PromptCache {
         return nil
     }

+    /// - Returns: true if all KV caches are trimmable
     public func isTrimmable() -> Bool {
         return cache.allSatisfy { $0.isTrimmable()}
     }

+    /// Trims all KV caches.
+    /// - Parameters:
+    ///   - n: Amount to trim.
+    /// - Returns: Amount KV Caches were trimmed (may be less than ``n``).
     public func trim(_ n: Int) -> Int {
         if !self.isTrimmable(){
             return 0
         }
         return cache.map { $0.trim(n: n) }.max() ?? 0
     }

+    /// Finds the common prefix between the cached prompt and
+    /// the new prompt.
+    /// - Parameters:
+    ///   - newPromptTokens: Tokens to compare with cached tokens.
+    /// - Returns: Length of the common prefix
     public func commonPrefixLength(newPromptTokens: MLXArray) -> Int {
-        return _commonPrefixLength(self.tokens, newPromptTokens)
+        return MLX_Studio.commonPrefixLength(self.tokens, newPromptTokens)
     }
+}

-    // TODO: Add tests
-    public func _commonPrefixLength(_ array1: MLXArray, _ array2: MLXArray) -> Int {
-        print("Calculating common prefix: array1[\(array1.size)] array2[\(array2.size)]")
-        let minLength = min(array1.size, array2.size)
-        for i in 0..<minLength {
-            if all(array1[i] .!= array2[i]).item(Bool.self) {
-                return i
-            }
+/// Finds the common prefix between ``MLXArray``s.
+/// - Parameters:
+///   - array1: First array
+///   - array2: Second array
+/// - Returns: Length of the common prefix
+public func commonPrefixLength(_ array1: MLXArray, _ array2: MLXArray) -> Int {
+    // TODO: Add test cases
+    print("Calculating common prefix: array1[\(array1.size)] array2[\(array2.size)]")
+    let minLength = min(array1.size, array2.size)
+    for i in 0..<minLength {
+        if all(array1[i] .!= array2[i]).item(Bool.self) {
+            return i
         }
-        return minLength
     }
+    return minLength
 }
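Note: a minimal sketch of what the new top-level commonPrefixLength enables. The token values are invented and MLXArray's array-literal initializer is assumed; in the chat example the ids come from the tokenizer.

import MLX

// Hypothetical token ids: the prompt already reflected in the KV cache,
// and the prompt for the next turn of the conversation.
let cachedTokens: MLXArray = [1, 2, 3, 4, 5]
let nextPrompt: MLXArray = [1, 2, 3, 9, 9, 9]

// Compares element-wise until the first mismatch, so `shared` is 3 here.
let shared = commonPrefixLength(cachedTokens, nextPrompt)
print("shared prefix length: \(shared)")

// PromptCache.getUncachedSuffix(prompt:) builds on this: the KV caches are
// trimmed back to `shared` tokens and only the suffix ([9, 9, 9] here) is
// re-processed by the model.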

Applications/MLXChatExample/Services/MLXService.swift

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,6 @@ import MLX
 import MLXLLM
 import MLXLMCommon
 import MLXVLM
-import Tokenizers // Needed for applyChatTemplate

 /// A service class that manages machine learning models for text and vision-language tasks.
 /// This class handles model loading, caching, and text generation using various LLM and VLM models.
@@ -119,6 +118,7 @@ class MLXService {

         let parameters = GenerateParameters(temperature: 0.7)

+        // Get the prompt cache
         let cache: PromptCache
         if let existingCache = self.promptCache[model.name] {
             cache = existingCache
@@ -131,7 +131,7 @@
         let lmInput: LMInput

         /// Remove prefix from prompt that is already in cache
-        if let suffix = await cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
+        if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
             lmInput = LMInput(text: LMInput.Text(tokens: suffix))
         } else {
             // If suffix is nil, the cache is inconsistent with the new prompt
@@ -143,7 +143,7 @@
             // TODO: cache.perform ...
             // TODO: The generated tokens should be added to the prompt cache but not possible with AsyncStream
             return try MLXLMCommon.generate(
-                input: lmInput, parameters: parameters, context: context, cache: await cache.cache)
+                input: lmInput, parameters: parameters, context: context, cache: cache.cache)
         }
     }
 }
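For orientation, here is a hypothetical, self-contained rendering of the flow these changes support in MLXService. The names promptCache, fullPrompt, and context mirror the diff above; the cache-creation call (model.newCache(parameters:)) and the recovery branch for an inconsistent cache are assumptions, since those lines fall outside the visible hunks.

import MLX
import MLXLMCommon

// Hypothetical helper mirroring MLXService.generate's caching path.
func generateWithPromptCache(
    promptCache: inout [String: PromptCache],
    modelName: String,
    context: ModelContext,
    fullPrompt: LMInput
) throws -> AsyncStream<Generation> {
    let parameters = GenerateParameters(temperature: 0.7)

    // Reuse the per-model cache when one exists, otherwise start a fresh one.
    // newCache(parameters:) is assumed to be how the KV caches are produced.
    let cache: PromptCache
    if let existing = promptCache[modelName] {
        cache = existing
    } else {
        cache = PromptCache(cache: context.model.newCache(parameters: parameters))
        promptCache[modelName] = cache
    }

    let lmInput: LMInput
    if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
        // Only the tokens not already represented in the KV cache are re-processed.
        lmInput = LMInput(text: LMInput.Text(tokens: suffix))
    } else {
        // Assumed recovery path: the cache diverged from the new prompt, so
        // replace it and process the full prompt.
        let fresh = PromptCache(cache: context.model.newCache(parameters: parameters))
        promptCache[modelName] = fresh
        return try MLXLMCommon.generate(
            input: fullPrompt, parameters: parameters, context: context, cache: fresh.cache)
    }

    return try MLXLMCommon.generate(
        input: lmInput, parameters: parameters, context: context, cache: cache.cache)
}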

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 3 additions & 0 deletions
@@ -529,6 +529,7 @@ public func generate(
 ///   - input: prepared language model input
 ///   - parameters: parameters controlling the token generation
 ///   - context: model context (model and tokenizer)
+///   - cache: KV cache from previous output
 ///   - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
 /// - Returns: the generated output
 public func generate(
@@ -626,6 +627,7 @@ public func generate(
 ///   - input: prepared language model input
 ///   - parameters: parameters controlling the token generation
 ///   - context: model context (model and tokenizer)
+///   - cache: KV cache from previous output
 ///   - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
 /// - Returns: Information about the generation
 public func generate(
@@ -702,6 +704,7 @@
 ///   - input: The input for the language model.
 ///   - parameters: The configuration options for token generation.
 ///   - context: The model context, including the model itself and associated tokenizer.
+///   - cache: KV cache from previous output
 /// - Returns: An `AsyncStream` that emits `Generation` values, including generated tokens (`.token`)
 ///   and completion information (`.info`).
 /// - Throws: An error if the `TokenIterator` initialization fails due to invalid input or model configuration.
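To place the newly documented cache: parameter, a hedged consumption sketch of the AsyncStream overload; lmInput, parameters, context, and cache are assumed to be set up as in MLXService above, and decoding through the context's tokenizer is illustrative.

// Pass the caches carried by PromptCache so generation continues from the
// cached prefix instead of re-evaluating the whole prompt.
let stream = try MLXLMCommon.generate(
    input: lmInput, parameters: parameters, context: context, cache: cache.cache)

for await generation in stream {
    // .token carries each generated token id; .info carries completion stats.
    if case .token(let token) = generation {
        print(context.tokenizer.decode(tokens: [token]), terminator: "")
    }
}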

Libraries/MLXLMCommon/ModelContainer.swift

Lines changed: 0 additions & 1 deletion
@@ -79,5 +79,4 @@ public actor ModelContainer {
     public func update(_ action: @Sendable (inout ModelContext) -> Void) {
         action(&context)
     }
-
 }

0 commit comments
