diff --git a/Applications/MLXChatExample/Models/PromptCache.swift b/Applications/MLXChatExample/Models/PromptCache.swift
new file mode 100644
index 00000000..090b62e5
--- /dev/null
+++ b/Applications/MLXChatExample/Models/PromptCache.swift
@@ -0,0 +1,124 @@
+//
+//  PromptCache.swift
+//  mlx-swift-examples
+//
+//  Created by Jolon Faichney on 3/5/2025.
+//
+
+import MLX
+import MLXLMCommon
+
+/// Stores the KV cache between calls to ``generate`` and maintains
+/// the token ids reflected in the cache.
+///
+/// ``PromptCache`` is `@unchecked Sendable`, which allows it
+/// to be used within the ``ModelContainer`` context.
+///
+/// TODO: cache isolation
+public class PromptCache: @unchecked Sendable {
+    private(set) var cache: [KVCache]
+    private(set) var tokens: MLXArray
+
+    public init(cache: [KVCache]) {
+        print("[PromptCache.init]")
+        self.cache = cache
+        self.tokens = []
+    }
+
+    /// Returns the suffix of the prompt that is not already in the cache, so that
+    /// only the new part is processed. The tokens of the cache are adjusted here
+    /// to reflect the new full prompt (i.e. the suffix tokens are added to the
+    /// cache tokens array), assuming that the prompt suffix will
+    /// be processed after the call to this function.
+    ///
+    /// Trims the cache if part of it no longer matches the new prompt.
+    /// If the cache needs to be trimmed but the model doesn't support
+    /// trimming, returns `nil` so the caller can create a new cache.
+    ///
+    /// - Returns:
+    ///     - If the entire cache is a prefix of the new prompt:
+    ///         - the suffix of the new prompt that is not yet in the cache.
+    ///     - If only a portion of the cache matches the new prompt:
+    ///         - the cache is trimmed to the common prefix and the
+    ///             remaining suffix of the prompt is returned.
+    ///         - If the cache is not trimmable, `nil` is returned so the
+    ///             caller can create a new cache.
+    public func getUncachedSuffix(prompt: MLXArray) -> MLXArray? {
+
+        print("[getUncachedSuffix] self.tokens.size = \(self.tokens.size)")
+
+        print("cache[\(self.tokens.size)]: \(self.tokens)")
+        print("prompt[\(prompt.size)]: \(prompt)")
+
+        let comPrefixLength = commonPrefixLength(newPromptTokens: prompt)
+        print("[getUncachedSuffix] comPrefixLength: \(comPrefixLength)")
+
+        if comPrefixLength == self.tokens.size {
+            let suffix = prompt[comPrefixLength ..< prompt.size]
+            print("Concatenating...")
+            self.tokens = concatenated([self.tokens, suffix], axis: 0)
+            return suffix
+        } else if comPrefixLength < self.tokens.size {
+            if isTrimmable() {
+                print("trimming: \(self.tokens.size - comPrefixLength)")
+                let trimmedLen = self.trim(self.tokens.size - comPrefixLength)
+                print("trimmed: \(trimmedLen)")
+                if trimmedLen != self.tokens.size - comPrefixLength {
+                    print("Warning: requested trim amount and actual trimmed amount differ")
+                }
+                self.tokens = self.tokens[0 ..< comPrefixLength]
+                let suffix = prompt[comPrefixLength ..< prompt.size]
+                self.tokens = concatenated([self.tokens, suffix], axis: 0)
+                return suffix
+            } else {
+                // Caller must create a new cache
+                return nil
+            }
+        }
+
+        return nil
+    }
+
+    /// - Returns: `true` if all KV caches are trimmable.
+    public func isTrimmable() -> Bool {
+        return cache.allSatisfy { $0.isTrimmable() }
+    }
+
+    /// Trims all KV caches.
+    /// - Parameters:
+    ///   - n: Number of tokens to trim.
+    /// - Returns: Number of tokens the KV caches were trimmed by (may be less than `n`).
+    public func trim(_ n: Int) -> Int {
+        if !self.isTrimmable() {
+            return 0
+        }
+        return cache.map { $0.trim(n: n) }.max() ?? 0
+    }
+
+    /// Finds the common prefix between the cached prompt and
+    /// the new prompt.
+    /// - Parameters:
+    ///   - newPromptTokens: Tokens to compare with the cached tokens.
+    /// - Returns: Length of the common prefix.
+    public func commonPrefixLength(newPromptTokens: MLXArray) -> Int {
+        return commonPrefixLength(self.tokens, newPromptTokens)
+    }
+
+    /// Finds the length of the common prefix between two ``MLXArray``s of tokens.
+    /// - Parameters:
+    ///   - array1: First array
+    ///   - array2: Second array
+    /// - Returns: Length of the common prefix.
+    public func commonPrefixLength(_ array1: MLXArray, _ array2: MLXArray) -> Int {
+        // TODO: Add test cases
+        print("Calculating common prefix: array1[\(array1.size)] array2[\(array2.size)]")
+        let minLength = min(array1.size, array2.size)
+        for i in 0 ..< minLength {
+            if all(array1[i] .!= array2[i]).item(Bool.self) {
+                return i
+            }
+        }
+        return minLength
+    }
+
+}
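To make the intended contract of `getUncachedSuffix` concrete, here is a minimal sketch of two chat turns against a single `PromptCache`. It is illustrative only: the empty `[]` KV-cache array and the literal token arrays stand in for `context.model.newCache(parameters:)` and real tokenizer output.

```swift
import MLX
import MLXLMCommon

// Illustrative only: [] stands in for context.model.newCache(parameters: parameters),
// and the literal token arrays stand in for tokenizer output.
let promptCache = PromptCache(cache: [])

// Turn 1: nothing is cached yet, so the whole prompt is returned as the suffix.
let firstPrompt: MLXArray = [1, 2, 3, 4]
let firstSuffix = promptCache.getUncachedSuffix(prompt: firstPrompt)
print(firstSuffix?.size ?? -1)  // 4: the full prompt still needs a forward pass

// Turn 2: the prompt extends the cached tokens, so only the new tokens remain.
let secondPrompt: MLXArray = [1, 2, 3, 4, 5, 6]
let secondSuffix = promptCache.getUncachedSuffix(prompt: secondPrompt)
print(secondSuffix?.size ?? -1)  // 2: only tokens 5 and 6 need processing

// A prompt that diverges from the cached tokens either trims the cache
// (when every KVCache is trimmable) or yields nil, telling the caller
// to start over with a fresh PromptCache.
```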
diff --git a/Applications/MLXChatExample/Services/MLXService.swift b/Applications/MLXChatExample/Services/MLXService.swift
index 942a2c70..2f4b07a1 100644
--- a/Applications/MLXChatExample/Services/MLXService.swift
+++ b/Applications/MLXChatExample/Services/MLXService.swift
@@ -19,6 +19,7 @@ class MLXService {
     /// Includes both language models (LLM) and vision-language models (VLM).
     static let availableModels: [LMModel] = [
         LMModel(name: "llama3.2:1b", configuration: LLMRegistry.llama3_2_1B_4bit, type: .llm),
+        LMModel(name: "llama3.2:3b", configuration: LLMRegistry.llama3_2_3B_4bit, type: .llm),
         LMModel(name: "qwen2.5:1.5b", configuration: LLMRegistry.qwen2_5_1_5b, type: .llm),
         LMModel(name: "smolLM:135m", configuration: LLMRegistry.smolLM_135M_4bit, type: .llm),
         LMModel(name: "qwen3:0.6b", configuration: LLMRegistry.qwen3_0_6b_4bit, type: .llm),
@@ -34,6 +35,9 @@ class MLXService {
     /// Cache to store loaded model containers to avoid reloading.
     private let modelCache = NSCache<NSString, ModelContainer>()
 
+    /// Stores a prompt cache for each loaded model.
+    private let promptCache = NSCache<NSString, PromptCache>()
+
     /// Tracks the current model download progress.
     /// Access this property to monitor model download status.
     @MainActor
@@ -51,6 +55,7 @@ class MLXService {
         if let container = modelCache.object(forKey: model.name as NSString) {
             return container
         } else {
+            print("Model \(model.name) not loaded, loading model...")
            // Select appropriate factory based on model type
            let factory: ModelFactory =
                switch model.type {
@@ -69,6 +74,9 @@ class MLXService {
                }
            }
 
+            // Clear out any existing prompt cache for this model
+            promptCache.removeObject(forKey: model.name as NSString)
+
             // Cache the loaded model for future use
             modelCache.setObject(container, forKey: model.name as NSString)
 
@@ -111,12 +119,54 @@ class MLXService {
 
         // Generate response using the model
         return try await modelContainer.perform { (context: ModelContext) in
-            let lmInput = try await context.processor.prepare(input: userInput)
-            // Set temperature for response randomness (0.7 provides good balance)
+
+            let fullPrompt = try await context.processor.prepare(input: userInput)
+
             let parameters = GenerateParameters(temperature: 0.7)
+
+            // TODO: Prompt cache access isn't isolated
+            // Get the prompt cache and strip the prefix that is already cached
+            // from the new prompt; trim the cache if it is inconsistent with
+            // the new prompt.
+            let (cache, lmInput) = getPromptCache(
+                fullPrompt: fullPrompt, parameters: parameters, context: context,
+                modelName: model.name)
 
+            // TODO: The generated tokens should be added to the prompt cache,
+            // but this isn't possible with AsyncStream.
             return try MLXLMCommon.generate(
-                input: lmInput, parameters: parameters, context: context)
+                input: lmInput, parameters: parameters, context: context, cache: cache.cache)
+        }
+    }
+
+    func getPromptCache(
+        fullPrompt: LMInput, parameters: GenerateParameters, context: ModelContext,
+        modelName: String
+    ) -> (PromptCache, LMInput) {
+        var cache: PromptCache
+        if let existingCache = promptCache.object(forKey: modelName as NSString) {
+            cache = existingCache
+        } else {
+            // Create a cache if one doesn't exist yet
+            cache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            self.promptCache.setObject(cache, forKey: modelName as NSString)
         }
+
+        let lmInput: LMInput
+
+        // Remove the prefix of the prompt that is already in the cache
+        if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
+            lmInput = LMInput(text: LMInput.Text(tokens: suffix))
+        } else {
+            // The cache is inconsistent with the new prompt and doesn't support
+            // trimming, so replace it with a fresh cache and process the full prompt.
+            cache = PromptCache(cache: context.model.newCache(parameters: parameters))
+            // Record the full prompt so the next turn can reuse this cache.
+            _ = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens)
+            self.promptCache.setObject(cache, forKey: modelName as NSString)
+            lmInput = fullPrompt
+        }
+
+        return (cache, lmInput)
     }
 }
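The service-level flow above can be summarized in a small, hypothetical helper: ask the `PromptCache` for the uncached suffix of the prepared prompt, then pass the underlying KV caches to `generate`. The name `generateTurn` is invented for illustration; only the `PromptCache`, `LMInput`, and `generate(input:parameters:context:cache:didGenerate:)` APIs from this diff are assumed.

```swift
import MLX
import MLXLMCommon

// Hypothetical helper mirroring MLXService.generate + getPromptCache above.
// Returns nil when the cache cannot be reconciled with the prompt, in which
// case the service replaces it with a fresh PromptCache and retries.
func generateTurn(
    fullPrompt: LMInput, promptCache: PromptCache,
    parameters: GenerateParameters, context: ModelContext
) throws -> GenerateCompletionInfo? {
    guard let suffix = promptCache.getUncachedSuffix(prompt: fullPrompt.text.tokens) else {
        return nil
    }
    // Only the uncached suffix is fed through the model; the KV arrays in
    // promptCache.cache already cover the rest of the prompt.
    return try MLXLMCommon.generate(
        input: LMInput(text: LMInput.Text(tokens: suffix)),
        parameters: parameters, context: context, cache: promptCache.cache
    ) { (_: Int) in .more }
}
```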
diff --git a/Applications/MLXChatExample/ViewModels/ChatViewModel.swift b/Applications/MLXChatExample/ViewModels/ChatViewModel.swift
index 06b1530b..1ce583f3 100644
--- a/Applications/MLXChatExample/ViewModels/ChatViewModel.swift
+++ b/Applications/MLXChatExample/ViewModels/ChatViewModel.swift
@@ -49,6 +49,11 @@ class ChatViewModel {
         generateCompletionInfo?.tokensPerSecond ?? 0
     }
 
+    /// Time to generate the first token, in seconds.
+    var timeToFirstToken: Double {
+        generateCompletionInfo?.promptTime ?? 0
+    }
+
     /// Progress of the current model download, if any
     var modelDownloadProgress: Progress? {
         mlxService.modelDownloadProgress
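`timeToFirstToken` reuses `GenerateCompletionInfo.promptTime`, i.e. the prompt prefill time, as a proxy for time-to-first-token. A small, hypothetical formatter shows how the two metrics surfaced by the view model fit together; the function name is invented for illustration.

```swift
import Foundation
import MLXLMCommon

// Hypothetical formatter for the metrics the view model exposes:
// promptTime (prefill) approximates time-to-first-token, and
// tokensPerSecond is the decode throughput.
func formatGenerationStats(_ info: GenerateCompletionInfo?) -> String {
    let ttft = info?.promptTime ?? 0
    let tps = info?.tokensPerSecond ?? 0
    return String(format: "TTFT: %.2f s, TPS: %.2f", ttft, tps)
}
```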
diff --git a/Applications/MLXChatExample/Views/Toolbar/ChatToolbarView.swift b/Applications/MLXChatExample/Views/Toolbar/ChatToolbarView.swift
index 06012eb3..37acb5be 100644
--- a/Applications/MLXChatExample/Views/Toolbar/ChatToolbarView.swift
+++ b/Applications/MLXChatExample/Views/Toolbar/ChatToolbarView.swift
@@ -29,7 +29,8 @@ struct ChatToolbarView: View {
                 vm.clear([.chat, .meta])
             } label: {
                 GenerationInfoView(
-                    tokensPerSecond: vm.tokensPerSecond
+                    tokensPerSecond: vm.tokensPerSecond,
+                    timeToFirstToken: vm.timeToFirstToken
                 )
             }
 
diff --git a/Applications/MLXChatExample/Views/Toolbar/GenerationInfoView.swift b/Applications/MLXChatExample/Views/Toolbar/GenerationInfoView.swift
index f9663f2b..fa5e79cd 100644
--- a/Applications/MLXChatExample/Views/Toolbar/GenerationInfoView.swift
+++ b/Applications/MLXChatExample/Views/Toolbar/GenerationInfoView.swift
@@ -9,12 +9,23 @@ import SwiftUI
 
 struct GenerationInfoView: View {
     let tokensPerSecond: Double
+    let timeToFirstToken: Double
 
     var body: some View {
-        Text("\(tokensPerSecond, format: .number.precision(.fractionLength(2))) tokens/s")
+        HStack {
+            if timeToFirstToken > 0 {
+                Text(String(format: "TTFT: %.2f s", timeToFirstToken))
+            }
+            if tokensPerSecond > 0 {
+                Text(String(format: "TPS: %.2f", tokensPerSecond))
+            }
+        }
+        .lineLimit(1)
+        .frame(minWidth: 150, alignment: .leading)
     }
 }
 
 #Preview {
-    GenerationInfoView(tokensPerSecond: 58.5834)
+    GenerationInfoView(tokensPerSecond: 58.5834, timeToFirstToken: 1.234)
+        .padding()
 }
diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift
index 26a27255..28829223 100644
--- a/Libraries/MLXLMCommon/Evaluate.swift
+++ b/Libraries/MLXLMCommon/Evaluate.swift
@@ -529,14 +529,15 @@ public func generate(
 ///   - input: prepared language model input
 ///   - parameters: parameters controlling the token generation
 ///   - context: model context (model and tokenizer)
+///   - cache: KV cache from a previous call, reused so the cached prompt prefix is not reprocessed
 ///   - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
 /// - Returns: the generated output
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext,
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
     didGenerate: ([Int]) -> GenerateDisposition
 ) throws -> GenerateResult {
     let iterator = try TokenIterator(
-        input: input, model: context.model, parameters: parameters)
+        input: input, model: context.model, cache: cache, parameters: parameters)
     return generate(
         input: input, context: context, iterator: iterator, didGenerate: didGenerate)
 }
@@ -626,14 +627,15 @@ public func generate(
 ///   - input: prepared language model input
 ///   - parameters: parameters controlling the token generation
 ///   - context: model context (model and tokenizer)
+///   - cache: KV cache from a previous call, reused so the cached prompt prefix is not reprocessed
 ///   - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
 /// - Returns: Information about the generation
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext,
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
     didGenerate: (Int) -> GenerateDisposition
 ) throws -> GenerateCompletionInfo {
     let iterator = try TokenIterator(
-        input: input, model: context.model, parameters: parameters)
+        input: input, model: context.model, cache: cache, parameters: parameters)
     return generate(
         input: input, context: context, iterator: iterator, didGenerate: didGenerate)
 }
@@ -702,6 +704,7 @@ public func generate(
 ///   - input: The input for the language model.
 ///   - parameters: The configuration options for token generation.
 ///   - context: The model context, including the model itself and associated tokenizer.
+///   - cache: KV cache from a previous call, reused so the cached prompt prefix is not reprocessed
 /// - Returns: An `AsyncStream` that emits `Generation` values, including generated tokens (`.token`)
 ///   and completion information (`.info`).
 /// - Throws: An error if the `TokenIterator` initialization fails due to invalid input or model configuration.
@@ -729,10 +732,10 @@ public func generate(
 ///   }
 /// ```
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext
+    input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil
 ) throws -> AsyncStream<Generation> {
     let iterator = try TokenIterator(
-        input: input, model: context.model, parameters: parameters)
+        input: input, model: context.model, cache: cache, parameters: parameters)
     return generate(
         input: input, context: context, iterator: iterator)
 }
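A minimal sketch of what the new `cache` parameter enables at the `Evaluate` level: create the KV caches once via `newCache` and thread them through successive `generate` calls so the second call resumes from the existing cache offset. `context`, the two inputs, and the function name are assumptions for illustration; the follow-up input is expected to contain only tokens not yet reflected in the cache.

```swift
import MLX
import MLXLMCommon

// Sketch: reuse one KV cache across two generate() calls.
func twoTurnGeneration(
    context: ModelContext, firstInput: LMInput, followUpInput: LMInput,
    parameters: GenerateParameters
) throws {
    // Created once per conversation and threaded through every call.
    let kvCache = context.model.newCache(parameters: parameters)

    // First call fills the KV cache while generating.
    let first = try MLXLMCommon.generate(
        input: firstInput, parameters: parameters, context: context, cache: kvCache
    ) { (_: Int) in .more }

    // Second call starts from the existing cache offset instead of
    // recomputing keys/values for everything that came before.
    let second = try MLXLMCommon.generate(
        input: followUpInput, parameters: parameters, context: context, cache: kvCache
    ) { (_: Int) in .more }

    print(first.tokensPerSecond, second.promptTime)
}
```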
diff --git a/Libraries/MLXLMCommon/KVCache.swift b/Libraries/MLXLMCommon/KVCache.swift
index 594d80c8..00bc46a3 100644
--- a/Libraries/MLXLMCommon/KVCache.swift
+++ b/Libraries/MLXLMCommon/KVCache.swift
@@ -12,6 +12,13 @@ public protocol KVCache: Evaluatable {
     var offset: Int { get }
 
     func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray)
+
+    /// Indicates whether this cache supports trimming.
+    func isTrimmable() -> Bool
+
+    /// Trims up to `n` tokens from the end of the cache.
+    /// - Returns: The number of tokens actually trimmed.
+    func trim(n: Int) -> Int
 }
 
 func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray {
@@ -97,4 +101,13 @@ public class KVCacheSimple: KVCache, Evaluatable {
         )
     }
 
+    public func isTrimmable() -> Bool {
+        return true
+    }
+
+    public func trim(n: Int) -> Int {
+        let toTrim = min(self.offset, n)
+        self.offset -= toTrim
+        return toTrim
+    }
 }
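These protocol additions are what `PromptCache.isTrimmable()` and `trim(_:)` build on. A small, hypothetical helper shows the same pattern applied to any `[KVCache]` stack; a cache type that cannot roll back would simply return `false` from `isTrimmable()` and make the helper bail out. The function name is an assumption for illustration.

```swift
import MLXLMCommon

// Hypothetical helper: roll an entire cache stack back by `count` tokens,
// but only if every layer's cache supports trimming.
func trimIfPossible(_ caches: [KVCache], count: Int) -> Int? {
    guard caches.allSatisfy({ $0.isTrimmable() }) else { return nil }
    // Each layer trims independently; report the largest amount removed.
    return caches.map { $0.trim(n: count) }.max() ?? 0
}
```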