
Commit e3645dc

Fix #310: Add support for prompt caching and example in MLXChatExample.
1 parent f17047c

File tree

7 files changed (+216, -12 lines)
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+//
+//  PromptCache.swift
+//  mlx-swift-examples
+//
+//  Created by Jolon Faichney on 3/5/2025.
+//
+
+import MLX
+import MLXLMCommon
+
+/// Stores the KV Cache between calls to ``generate`` and maintains
+/// the token ids reflected in the cache.
+///
+/// ``PromptCache`` is ``@unchecked Sendable`` which allows it
+/// to be used within the ``ModelContainer`` context.
+///
+/// TODO: cache isolation
+public class PromptCache: @unchecked Sendable {
+    private(set) var cache: [KVCache]
+    private(set) var tokens: MLXArray
+
+    public init(cache: [KVCache]) {
+        print("[PromptCache.init]")
+        self.cache = cache
+        self.tokens = []
+    }
+
+    /// Returns the suffix of the prompt that is not already in the cache, so that
+    /// only the new part is processed. The cached tokens are adjusted here to
+    /// reflect the new full prompt (i.e. the suffix tokens are appended to the
+    /// cache's token array), on the assumption that the prompt suffix will be
+    /// processed after the call to this function.
+    ///
+    /// Trims the cache if part of it doesn't match the new prompt. If the model
+    /// doesn't support trimming and the cache needs to be trimmed, returns nil
+    /// so the caller can create a new cache.
+    ///
+    /// - Returns:
+    ///   - If the entire cache is a prefix of the new prompt:
+    ///     returns the suffix of the new prompt, less what is already in the cache.
+    ///   - If only a portion of the cache is in the new prompt:
+    ///     attempts to trim the cache to the common prefix and returns the
+    ///     suffix of the prompt not in the cache. If the cache is not
+    ///     trimmable, returns nil so the caller can create a new cache.
+    public func getUncachedSuffix(prompt: MLXArray) -> MLXArray? {
+
+        print("[getUncachedSuffix] self.tokens.size = \(self.tokens.size)")
+
+        print("cache[\(self.tokens.size)]: \(self.tokens)")
+        print("prompt[\(prompt.size)]: \(prompt)")
+
+        let comPrefixLength = commonPrefixLength(newPromptTokens: prompt)
+        print("[getUncachedSuffix] comPrefixLength: \(comPrefixLength)")
+
+        if comPrefixLength == self.tokens.size {
+            let suffix = prompt[comPrefixLength ..< prompt.size]
+            print("Concatenating...")
+            self.tokens = concatenated([self.tokens, suffix], axis: 0)
+            return suffix
+        } else if comPrefixLength < self.tokens.size {
+            if isTrimmable() {
+                print("trimming: \(self.tokens.size - comPrefixLength)")
+                let trimmedLen = self.trim(self.tokens.size - comPrefixLength)
+                print("trimmed: \(trimmedLen)")
+                if trimmedLen != self.tokens.size - comPrefixLength {
+                    print("Warning: requested trim amount and actual trimmed amount differ")
+                }
+                self.tokens = self.tokens[0 ..< comPrefixLength]
+                let suffix = prompt[comPrefixLength ..< prompt.size]
+                self.tokens = concatenated([self.tokens, suffix], axis: 0)
+                return suffix
+            } else {
+                // Caller must create a new cache
+                return nil
+            }
+        }
+
+        return nil
+    }
+
+    /// - Returns: true if all KV caches are trimmable
+    public func isTrimmable() -> Bool {
+        return cache.allSatisfy { $0.isTrimmable() }
+    }
+
+    /// Trims all KV caches.
+    /// - Parameters:
+    ///   - n: Amount to trim.
+    /// - Returns: Amount the KV caches were trimmed (may be less than ``n``).
+    public func trim(_ n: Int) -> Int {
+        if !self.isTrimmable() {
+            return 0
+        }
+        return cache.map { $0.trim(n: n) }.max() ?? 0
+    }
+
+    /// Finds the common prefix between the cached prompt and the new prompt.
+    /// - Parameters:
+    ///   - newPromptTokens: Tokens to compare with the cached tokens.
+    /// - Returns: Length of the common prefix
+    public func commonPrefixLength(newPromptTokens: MLXArray) -> Int {
+        return commonPrefixLength(self.tokens, newPromptTokens)
+    }
+
+    /// Finds the common prefix between two ``MLXArray``s.
+    /// - Parameters:
+    ///   - array1: First array
+    ///   - array2: Second array
+    /// - Returns: Length of the common prefix
+    public func commonPrefixLength(_ array1: MLXArray, _ array2: MLXArray) -> Int {
+        // TODO: Add test cases
+        print("Calculating common prefix: array1[\(array1.size)] array2[\(array2.size)]")
+        let minLength = min(array1.size, array2.size)
+        for i in 0 ..< minLength {
+            if all(array1[i] .!= array2[i]).item(Bool.self) {
+                return i
+            }
+        }
+        return minLength
+    }
+}
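
For orientation, here is a minimal sketch (not part of the commit) of how PromptCache behaves across two chat turns. The kvCaches parameter stands in for the result of context.model.newCache(parameters:), and the token values are made up:

import MLX
import MLXLMCommon

// Hypothetical illustration of getUncachedSuffix across two turns.
func illustratePromptCache(kvCaches: [KVCache]) {
    let promptCache = PromptCache(cache: kvCaches)

    // Turn 1: nothing is cached yet, so the whole prompt is returned
    // and recorded as the cached token prefix.
    let turn1: MLXArray = [1, 2, 3, 4]
    let suffix1 = promptCache.getUncachedSuffix(prompt: turn1)  // [1, 2, 3, 4]

    // Turn 2: the new prompt extends the old one, so only the new
    // tokens need to be fed through the model.
    let turn2: MLXArray = [1, 2, 3, 4, 5, 6]
    let suffix2 = promptCache.getUncachedSuffix(prompt: turn2)  // [5, 6]

    // A prompt that diverges from the cached tokens either trims the
    // KV caches back to the common prefix or, if trimming isn't
    // supported, returns nil so the caller can build a fresh cache.
    print(suffix1?.size ?? 0, suffix2?.size ?? 0)
}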

Applications/MLXChatExample/Services/MLXService.swift

Lines changed: 50 additions & 3 deletions
@@ -19,6 +19,7 @@ class MLXService {
     /// Includes both language models (LLM) and vision-language models (VLM).
     static let availableModels: [LMModel] = [
         LMModel(name: "llama3.2:1b", configuration: LLMRegistry.llama3_2_1B_4bit, type: .llm),
+        LMModel(name: "llama3.2:3b", configuration: LLMRegistry.llama3_2_3B_4bit, type: .llm),
         LMModel(name: "qwen2.5:1.5b", configuration: LLMRegistry.qwen2_5_1_5b, type: .llm),
         LMModel(name: "smolLM:135m", configuration: LLMRegistry.smolLM_135M_4bit, type: .llm),
         LMModel(name: "qwen3:0.6b", configuration: LLMRegistry.qwen3_0_6b_4bit, type: .llm),
@@ -34,6 +35,9 @@ class MLXService {
     /// Cache to store loaded model containers to avoid reloading.
     private let modelCache = NSCache<NSString, ModelContainer>()

+    /// Stores a prompt cache for each loaded model
+    private let promptCache = NSCache<NSString, PromptCache>()
+
     /// Tracks the current model download progress.
     /// Access this property to monitor model download status.
     @MainActor
@@ -51,6 +55,7 @@ class MLXService {
         if let container = modelCache.object(forKey: model.name as NSString) {
             return container
         } else {
+            print("Model not loaded \(model.name), loading model...")
             // Select appropriate factory based on model type
             let factory: ModelFactory =
                 switch model.type {
@@ -69,6 +74,9 @@ class MLXService {
                 }
             }

+            // Clear out the promptCache
+            promptCache.removeObject(forKey: model.name as NSString)
+
             // Cache the loaded model for future use
             modelCache.setObject(container, forKey: model.name as NSString)

@@ -111,12 +119,51 @@ class MLXService {

         // Generate response using the model
         return try await modelContainer.perform { (context: ModelContext) in
-            let lmInput = try await context.processor.prepare(input: userInput)
-            // Set temperature for response randomness (0.7 provides good balance)
+
+            let fullPrompt = try await context.processor.prepare(input: userInput)
+
             let parameters = GenerateParameters(temperature: 0.7)

+            // TODO: Prompt cache access isn't isolated
+            // Get the prompt cache and adjust the new prompt to remove the
+            // prefix already in the cache; trim the cache if it is
+            // inconsistent with the new prompt.
+            let (cache, lmInput) = getPromptCache(
+                fullPrompt: fullPrompt, parameters: parameters, context: context,
+                modelName: model.name)
+
+            // TODO: The generated tokens should be added to the prompt cache, but this is not possible with AsyncStream
             return try MLXLMCommon.generate(
-                input: lmInput, parameters: parameters, context: context)
+                input: lmInput, parameters: parameters, context: context, cache: cache.cache)
+        }
+    }
+
+    func getPromptCache(
+        fullPrompt: LMInput, parameters: GenerateParameters, context: ModelContext,
+        modelName: String
+    ) -> (PromptCache, LMInput) {
+        let cache: PromptCache
+        if let existingCache = promptCache.object(forKey: modelName as NSString) {
+            cache = existingCache
+        } else {
+            // Create a cache if it doesn't exist yet
+            cache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(cache, forKey: modelName as NSString)
         }
+
+        let lmInput: LMInput
+
+        // Remove the prefix from the prompt that is already in the cache
+        if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
+            lmInput = LMInput(text: LMInput.Text(tokens: suffix))
+        } else {
+            // If the suffix is nil, the cache is inconsistent with the new prompt
+            // and doesn't support trimming, so create a new one here.
+            let newCache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(newCache, forKey: modelName as NSString)
+            lmInput = fullPrompt
+        }
+
+        return (cache, lmInput)
     }
 }
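
The service keeps one PromptCache per model name and drops it whenever that model is (re)loaded. As a rough sketch of that keying pattern in isolation (the PromptCacheStore type below is hypothetical, not part of this commit):

import Foundation
import MLXLMCommon

// Hypothetical helper mirroring the NSCache usage above:
// one PromptCache per model name, discarded when the model is reloaded.
final class PromptCacheStore {
    private let store = NSCache<NSString, PromptCache>()

    func cache(for modelName: String, makeCache: () -> PromptCache) -> PromptCache {
        let key = modelName as NSString
        if let existing = store.object(forKey: key) {
            return existing
        }
        let fresh = makeCache()
        store.setObject(fresh, forKey: key)
        return fresh
    }

    func invalidate(modelName: String) {
        // Mirrors promptCache.removeObject(forKey:) when a model is loaded
        store.removeObject(forKey: modelName as NSString)
    }
}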

Applications/MLXChatExample/ViewModels/ChatViewModel.swift

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,11 @@ class ChatViewModel {
         generateCompletionInfo?.tokensPerSecond ?? 0
     }

+    /// Time to generate the first token, in seconds
+    var timeToFirstToken: Double {
+        generateCompletionInfo?.promptTime ?? 0
+    }
+
     /// Progress of the current model download, if any
     var modelDownloadProgress: Progress? {
         mlxService.modelDownloadProgress

Applications/MLXChatExample/Views/Toolbar/ChatToolbarView.swift

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@ struct ChatToolbarView: View {
             vm.clear([.chat, .meta])
         } label: {
             GenerationInfoView(
-                tokensPerSecond: vm.tokensPerSecond
+                tokensPerSecond: vm.tokensPerSecond,
+                timeToFirstToken: vm.timeToFirstToken
             )
         }
Applications/MLXChatExample/Views/Toolbar/GenerationInfoView.swift

Lines changed: 13 additions & 2 deletions
@@ -9,12 +9,23 @@ import SwiftUI

 struct GenerationInfoView: View {
     let tokensPerSecond: Double
+    let timeToFirstToken: Double

     var body: some View {
-        Text("\(tokensPerSecond, format: .number.precision(.fractionLength(2))) tokens/s")
+        HStack {
+            if timeToFirstToken > 0 {
+                Text(String(format: "TTFT: %.2f s", timeToFirstToken))
+            }
+            if tokensPerSecond > 0 {
+                Text(String(format: "TPS: %.2f", tokensPerSecond))
+            }
+        }
+        .lineLimit(1)
+        .frame(minWidth: 150, alignment: .leading)
     }
 }

 #Preview {
-    GenerationInfoView(tokensPerSecond: 58.5834)
+    GenerationInfoView(tokensPerSecond: 58.5834, timeToFirstToken: 1.234)
+        .padding()
 }

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 9 additions & 6 deletions
@@ -529,14 +529,15 @@ public func generate(
 ///   - input: prepared language model input
 ///   - parameters: parameters controlling the token generation
 ///   - context: model context (model and tokenizer)
+///   - cache: KV cache from previous output
 ///   - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
 /// - Returns: the generated output
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext,
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
     didGenerate: ([Int]) -> GenerateDisposition
 ) throws -> GenerateResult {
     let iterator = try TokenIterator(
-        input: input, model: context.model, parameters: parameters)
+        input: input, model: context.model, cache: cache, parameters: parameters)
     return generate(
         input: input, context: context, iterator: iterator, didGenerate: didGenerate)
 }
@@ -626,14 +627,15 @@ public func generate(
 ///   - input: prepared language model input
 ///   - parameters: parameters controlling the token generation
 ///   - context: model context (model and tokenizer)
+///   - cache: KV cache from previous output
 ///   - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
 /// - Returns: Information about the generation
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext,
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
     didGenerate: (Int) -> GenerateDisposition
 ) throws -> GenerateCompletionInfo {
     let iterator = try TokenIterator(
-        input: input, model: context.model, parameters: parameters)
+        input: input, model: context.model, cache: cache, parameters: parameters)
     return generate(
         input: input, context: context, iterator: iterator, didGenerate: didGenerate)
 }
@@ -702,6 +704,7 @@ public func generate(
 ///   - input: The input for the language model.
 ///   - parameters: The configuration options for token generation.
 ///   - context: The model context, including the model itself and associated tokenizer.
+///   - cache: KV cache from previous output
 /// - Returns: An `AsyncStream` that emits `Generation` values, including generated tokens (`.token`)
 ///   and completion information (`.info`).
 /// - Throws: An error if the `TokenIterator` initialization fails due to invalid input or model configuration.
@@ -729,10 +732,10 @@ public func generate(
 /// }
 /// ```
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil
 ) throws -> AsyncStream<Generation> {
     let iterator = try TokenIterator(
-        input: input, model: context.model, parameters: parameters)
+        input: input, model: context.model, cache: cache, parameters: parameters)
     return generate(
         input: input, context: context, iterator: iterator)
 }
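
A rough usage sketch of the new optional cache: parameter, reusing one KV cache across two calls. The context and the two LMInput values are assumed to come from the usual ModelContainer.perform / processor.prepare flow and are not shown here:

import MLXLMCommon

// Sketch: carry a KV cache from one generate call into the next.
func generateWithReusedCache(
    context: ModelContext, firstInput: LMInput, suffixInput: LMInput
) throws {
    let parameters = GenerateParameters(temperature: 0.7)

    // One KV cache, created once and passed to both calls.
    let kvCache = context.model.newCache(parameters: parameters)

    _ = try MLXLMCommon.generate(
        input: firstInput, parameters: parameters, context: context, cache: kvCache
    ) { (tokens: [Int]) in .more }

    // The second call continues from the KV state left by the first one.
    // In the chat example, the input passed here contains only the
    // uncached suffix of the new prompt (see PromptCache.getUncachedSuffix).
    _ = try MLXLMCommon.generate(
        input: suffixInput, parameters: parameters, context: context, cache: kvCache
    ) { (tokens: [Int]) in .more }
}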

Libraries/MLXLMCommon/KVCache.swift

Lines changed: 13 additions & 0 deletions
@@ -12,6 +12,10 @@ public protocol KVCache: Evaluatable {
     var offset: Int { get }

     func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray)
+
+    func isTrimmable() -> Bool
+
+    func trim(n: Int) -> Int
 }

 func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray {
@@ -97,4 +101,13 @@ public class KVCacheSimple: KVCache, Evaluatable {
         )
     }

+    public func isTrimmable() -> Bool {
+        return true
+    }
+
+    public func trim(n: Int) -> Int {
+        let toTrim = min(self.offset, n)
+        self.offset -= toTrim
+        return toTrim
+    }
 }
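
A small sketch of the new protocol requirements in use. The rollBack helper below is hypothetical, not part of the commit; with KVCacheSimple, rolling back by 4 positions on a cache whose offset is 10 leaves the offset at 6 and returns 4, and asking for more than is cached trims only what exists:

import MLXLMCommon

// Hypothetical helper: rewind a KV cache by `count` positions if the
// concrete cache type supports trimming; returns how many positions
// were actually removed (0 if the cache is not trimmable).
func rollBack(_ cache: KVCache, by count: Int) -> Int {
    guard cache.isTrimmable() else { return 0 }
    return cache.trim(n: count)
}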
