
Feature: prompt caching (Fixes #310) #312

Open · wants to merge 1 commit into base: main
124 changes: 124 additions & 0 deletions Applications/MLXChatExample/Models/PromptCache.swift
@@ -0,0 +1,124 @@
//
// PromptCache.swift
// mlx-swift-examples
//
// Created by Jolon Faichney on 3/5/2025.
//

import MLX
import MLXLMCommon

/// Stores the KV Cache between calls to ``generate`` and maintains
/// the token ids reflected in the cache.
///
/// ``PromptCache`` is ``@unchecked Sendable`` which allows it
/// to be used within the ``ModelContainer`` context.
///
/// TODO: cache isolation
Collaborator:

I think this part is key -- I think this will need a lock and a method like:

func withCache<R>(_ block: ([KVCache]) throws -> R) rethrows -> R

Author:

I agree. I've forgotten some of the details, but I hit roadblocks with each approach. I think there was an issue with ModelContainer.perform() being asynchronous and trying to wrap that with something like withCache.

I might have to leave it to someone with more expertise in Swift concurrency.
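
(Aside, not part of this diff: a minimal sketch of the kind of lock-guarded accessor suggested above, using Foundation's NSLock. The type name GuardedKVCache is hypothetical, and whether this composes with the async ModelContainer.perform is exactly the open question in this thread.)

    import Foundation
    import MLXLMCommon

    /// Hypothetical lock-guarded holder for the KV cache; sketch only.
    public final class GuardedKVCache: @unchecked Sendable {
        private let lock = NSLock()
        private var cache: [KVCache]

        public init(cache: [KVCache]) {
            self.cache = cache
        }

        /// All access to the KV cache happens while the lock is held.
        public func withCache<R>(_ block: ([KVCache]) throws -> R) rethrows -> R {
            lock.lock()
            defer { lock.unlock() }
            return try block(cache)
        }
    }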

public class PromptCache: @unchecked Sendable {
private(set) var cache: [KVCache]
private(set) var tokens: MLXArray

public init(cache: [KVCache]) {
print("[PromptCache.init]")
self.cache = cache
self.tokens = []
}

/// Returns the suffix of the prompt not already in cache, so that only
/// the new part is processed. The tokens of the cache are adjusted here
/// to reflect the new full prompt (i.e. the suffix tokens are added to the
/// cache tokens array), assuming that the prompt suffix will
/// be processed after the call to this function.
///
/// Trims the cache if part of it doesn't match the new prompt. If the
/// cache needs to be trimmed but the model doesn't support trimming,
/// returns nil so the caller can create a new cache.
///
/// - Returns:
///   - If the entire cache is a prefix of the new prompt:
///     the suffix of the new prompt that is not already in the cache.
///   - If only a portion of the cache matches the new prompt:
///     the cache is trimmed to the common prefix and the remaining
///     suffix of the prompt is returned; if the cache is not trimmable,
///     returns nil so the caller can create a new cache.
public func getUncachedSuffix(prompt: MLXArray) -> MLXArray? {

print("[getUncachedSuffix] self.tokens.size = \(self.tokens.size)")

print("cache[\(self.tokens.size)]: \(self.tokens)")
print("prompt[\(prompt.size)]: \(prompt)")
Collaborator:

  • remove debug printing


let comPrefixLength = commonPrefixLength(newPromptTokens: prompt)
print("[getUncachedSuffix] comPrefixLength: \(comPrefixLength)")

if comPrefixLength == self.tokens.size {
let suffix = prompt[comPrefixLength ..< prompt.size]
print("Concatenating...")
self.tokens = concatenated([self.tokens, suffix], axis: 0)
return suffix
} else if comPrefixLength < self.tokens.size {
if isTrimmable() {
print("trimming: \(self.tokens.size - comPrefixLength)")
let trimmedLen = self.trim(self.tokens.size - comPrefixLength)
print("trimmed: \(trimmedLen)")
if trimmedLen != self.tokens.size - comPrefixLength {
print("Warning: requested trim amount and actual trimmed amount differ")
}
self.tokens = self.tokens[0 ..< comPrefixLength]
let suffix = prompt[comPrefixLength ..< prompt.size]
self.tokens = concatenated([self.tokens, suffix], axis: 0)
return suffix
} else {
// Caller must create a new cache
return nil
}
}

return nil
}

/// - Returns: true if all KV caches are trimmable
public func isTrimmable() -> Bool {
return cache.allSatisfy { $0.isTrimmable() }
}

/// Trims all KV caches.
/// - Parameters:
/// - n: Amount to trim.
/// - Returns: The amount the KV caches were trimmed (may be less than ``n``).
public func trim(_ n: Int) -> Int {
if !self.isTrimmable() {
return 0
}
return cache.map { $0.trim(n: n) }.max() ?? 0
}

/// Finds the common prefix between the cached prompt and
/// the new prompt.
/// - Parameters:
/// - newPromptTokens: Tokens to compare with cached tokens.
/// - Returns: Length of the common prefix
public func commonPrefixLength(newPromptTokens: MLXArray) -> Int {
return commonPrefixLength(self.tokens, newPromptTokens)
}

/// Finds the common prefix between ``MLXArray``s.
/// - Parameters:
/// - array1: First array
/// - array2: Second array
/// - Returns: Length of the common prefix
public func commonPrefixLength(_ array1: MLXArray, _ array2: MLXArray) -> Int {
// TODO: Add test cases
print("Calculating common prefix: array1[\(array1.size)] array2[\(array2.size)]")
let minLength = min(array1.size, array2.size)
for i in 0 ..< minLength {
if all(array1[i] .!= array2[i]).item(Bool.self) {
return i
}
}
return minLength
}

}
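
A rough usage sketch (not part of this PR; the token values are made up, model and parameters are assumed to be in scope, and this relies on the MLXArray integer-array initializer) showing how the cache behaves across chat turns:

    // Turn 1: nothing is cached yet, so the whole prompt comes back as the suffix.
    let promptCache = PromptCache(cache: model.newCache(parameters: parameters))
    let turn1 = MLXArray([1, 2, 3, 4])
    let suffix1 = promptCache.getUncachedSuffix(prompt: turn1)   // [1, 2, 3, 4]

    // Turn 2: the new prompt extends the old one, so only the new tokens come back.
    let turn2 = MLXArray([1, 2, 3, 4, 5, 6])
    let suffix2 = promptCache.getUncachedSuffix(prompt: turn2)   // [5, 6]

    // Diverging prompt: the cache is trimmed to the common prefix when the model
    // supports trimming; otherwise nil is returned and the caller must create a
    // fresh PromptCache.
    let turn3 = MLXArray([1, 2, 9])
    let suffix3 = promptCache.getUncachedSuffix(prompt: turn3)   // [9], or nil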
53 changes: 50 additions & 3 deletions Applications/MLXChatExample/Services/MLXService.swift
@@ -19,6 +19,7 @@ class MLXService {
/// Includes both language models (LLM) and vision-language models (VLM).
static let availableModels: [LMModel] = [
LMModel(name: "llama3.2:1b", configuration: LLMRegistry.llama3_2_1B_4bit, type: .llm),
LMModel(name: "llama3.2:3b", configuration: LLMRegistry.llama3_2_3B_4bit, type: .llm),
LMModel(name: "qwen2.5:1.5b", configuration: LLMRegistry.qwen2_5_1_5b, type: .llm),
LMModel(name: "smolLM:135m", configuration: LLMRegistry.smolLM_135M_4bit, type: .llm),
LMModel(name: "qwen3:0.6b", configuration: LLMRegistry.qwen3_0_6b_4bit, type: .llm),
@@ -34,6 +35,9 @@ class MLXService {
/// Cache to store loaded model containers to avoid reloading.
private let modelCache = NSCache<NSString, ModelContainer>()

/// Stores a prompt cache for each loaded model
private let promptCache = NSCache<NSString, PromptCache>()

/// Tracks the current model download progress.
/// Access this property to monitor model download status.
@MainActor
@@ -51,6 +55,7 @@ class MLXService {
if let container = modelCache.object(forKey: model.name as NSString) {
return container
} else {
print("Model not loaded \(model.name), loading model...")
// Select appropriate factory based on model type
let factory: ModelFactory =
switch model.type {
@@ -69,6 +74,9 @@ class MLXService {
}
}

// Clear out the promptCache
promptCache.removeObject(forKey: model.name as NSString)

// Cache the loaded model for future use
modelCache.setObject(container, forKey: model.name as NSString)

@@ -111,12 +119,51 @@ class MLXService {

// Generate response using the model
return try await modelContainer.perform { (context: ModelContext) in
let lmInput = try await context.processor.prepare(input: userInput)
// Set temperature for response randomness (0.7 provides good balance)

let fullPrompt = try await context.processor.prepare(input: userInput)

let parameters = GenerateParameters(temperature: 0.7)

// TODO: Prompt cache access isn't isolated
// Get the prompt cache and strip from the new prompt the prefix that
// is already cached; the cache is trimmed if it is inconsistent with
// the new prompt.
let (cache, lmInput) = getPromptCache(
fullPrompt: fullPrompt, parameters: parameters, context: context,
modelName: model.name)

// TODO: The generated tokens should be added to the prompt cache, but that isn't possible with AsyncStream
return try MLXLMCommon.generate(
input: lmInput, parameters: parameters, context: context)
input: lmInput, parameters: parameters, context: context, cache: cache.cache)
}
}

func getPromptCache(
fullPrompt: LMInput, parameters: GenerateParameters, context: ModelContext,
modelName: String
) -> (PromptCache, LMInput) {
let cache: PromptCache
if let existingCache = promptCache.object(forKey: modelName as NSString) {
cache = existingCache
} else {
// Create cache if it doesn't exist yet
cache = PromptCache(cache: context.model.newCache(parameters: parameters))
self.promptCache.setObject(cache, forKey: modelName as NSString)
}

let lmInput: LMInput

// Remove the prefix of the prompt that is already in the cache
if let suffix = cache.getUncachedSuffix(prompt: fullPrompt.text.tokens) {
lmInput = LMInput(text: LMInput.Text(tokens: suffix))
Collaborator:

I thought that the KVCache and the prompt should match in size -- that is, I thought the prompt should not have the pieces that are already in the KVCache trimmed off. Hrm, perhaps I am confused; here is the LLM prefill code:

        while y.tokens.size > prefillStepSize {
            let input = y[.newAxis, ..<prefillStepSize]
            let result = self(input, cache: cache.isEmpty ? nil : cache, state: state)
            eval(cache)
            y = y[prefillStepSize...]
        }

and I think it matches what you are doing here.

Author:

Looking at that code, it looks to me like it is just passing the entire prompt (which is whatever was passed to it) through the model, using the existing cache.

To be able to do trimming, the cache or the model would need a record of all of the tokens in the cache up to that point, and I'm not sure it does. But if it did, the trimming logic could be moved there.

One issue we have at the moment with PromptCache is that it is responsible for recording the tokens represented by the cache. However, because AsyncStream doesn't return tokens, we have no way of updating the cache's token list with the generated response. As a result we always trim the previous response from the new prompt, because the PromptCache doesn't know it is in the cache.

Either AsyncStream should return the tokens (which may not be a bad idea anyway), or the cache should be moved closer to where the tokens are generated so they can be added there. However, KVCache doesn't keep a record of the tokens (I don't think?), which is why we need PromptCache with its tokens MLXArray; so if the cache were to be fully managed within TokenIterator, we would need to pass it the full PromptCache.

Collaborator:

I understand the chat (turn taking) style use of KVCache a little bit more. I don't think we need to observe tokens directly -- KVCache already represents that.

If we want to trim the cache we have a couple of options:

  • we can record the index for particular points in the conversation and roll back to those
  • we can (I think) tokenize a truncated prompt with the back and forth conversation to get a token count and trim to that offset

If we put the tokens in the AsyncStream, that requires the holder of the stream to use the streaming detokenizer (which is not Sendable) and complicates things.

Now you may be right -- there may be certain cases where we do need the tokens, but I wonder if that should be handled synchronously inside the TokenIterator?

} else {
// If suffix is nil, the cache is inconsistent with the new prompt
// and the cache doesn't support trimming so create a new one here.
let newCache = PromptCache(cache: context.model.newCache(parameters: parameters))
self.promptCache.setObject(newCache, forKey: modelName as NSString)
lmInput = fullPrompt
}

return (cache, lmInput)
}
}
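
(Not part of this diff: a rough sketch of the first option mentioned in the review discussion above, recording the KV cache offset at a turn boundary and rolling back via the trim support added in this PR. The TurnCheckpoint type and function names are hypothetical, and this assumes every layer's cache reports the same offset.)

    import MLXLMCommon

    /// Hypothetical checkpoint recording the KV cache offset at a turn boundary.
    struct TurnCheckpoint {
        let offset: Int
    }

    func checkpoint(_ cache: [KVCache]) -> TurnCheckpoint {
        TurnCheckpoint(offset: cache.first?.offset ?? 0)
    }

    /// Rolls the cache back to a previously recorded turn boundary.
    func rollBack(_ cache: [KVCache], to checkpoint: TurnCheckpoint) {
        for layer in cache where layer.isTrimmable() {
            _ = layer.trim(n: layer.offset - checkpoint.offset)
        }
    }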
5 changes: 5 additions & 0 deletions Applications/MLXChatExample/ViewModels/ChatViewModel.swift
@@ -49,6 +49,11 @@ class ChatViewModel {
generateCompletionInfo?.tokensPerSecond ?? 0
}

/// Time to generate the first token in seconds
var timeToFirstToken: Double {
generateCompletionInfo?.promptTime ?? 0
}

/// Progress of the current model download, if any
var modelDownloadProgress: Progress? {
mlxService.modelDownloadProgress
@@ -29,7 +29,8 @@ struct ChatToolbarView: View {
vm.clear([.chat, .meta])
} label: {
GenerationInfoView(
tokensPerSecond: vm.tokensPerSecond
tokensPerSecond: vm.tokensPerSecond,
timeToFirstToken: vm.timeToFirstToken
)
}

15 changes: 13 additions & 2 deletions Applications/MLXChatExample/Views/Toolbar/GenerationInfoView.swift
@@ -9,12 +9,23 @@ import SwiftUI

struct GenerationInfoView: View {
let tokensPerSecond: Double
let timeToFirstToken: Double

var body: some View {
Text("\(tokensPerSecond, format: .number.precision(.fractionLength(2))) tokens/s")
HStack {
if timeToFirstToken > 0 {
Text(String(format: "TTFT: %.2f s", timeToFirstToken))
}
if tokensPerSecond > 0 {
Text(String(format: "TPS: %.2f", tokensPerSecond))
}
}
.lineLimit(1)
.frame(minWidth: 150, alignment: .leading)
}
}

#Preview {
GenerationInfoView(tokensPerSecond: 58.5834)
GenerationInfoView(tokensPerSecond: 58.5834, timeToFirstToken: 1.234)
.padding()
}
15 changes: 9 additions & 6 deletions Libraries/MLXLMCommon/Evaluate.swift
@@ -529,14 +529,15 @@ public func generate(
/// - input: prepared language model input
/// - parameters: parameters controlling the token generation
/// - context: model context (model and tokenizer)
/// - cache: KV cache from previous output
/// - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
/// - Returns: the generated output
public func generate(
input: LMInput, parameters: GenerateParameters, context: ModelContext,
input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
Collaborator:

I think some of this is already in place now.

didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult {
let iterator = try TokenIterator(
input: input, model: context.model, parameters: parameters)
input: input, model: context.model, cache: cache, parameters: parameters)
return generate(
input: input, context: context, iterator: iterator, didGenerate: didGenerate)
}
@@ -626,14 +627,15 @@ public func generate(
/// - input: prepared language model input
/// - parameters: parameters controlling the token generation
/// - context: model context (model and tokenizer)
/// - cache: KV cache from previous output
/// - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
/// - Returns: Information about the generation
public func generate(
input: LMInput, parameters: GenerateParameters, context: ModelContext,
input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
didGenerate: (Int) -> GenerateDisposition
) throws -> GenerateCompletionInfo {
let iterator = try TokenIterator(
input: input, model: context.model, parameters: parameters)
input: input, model: context.model, cache: cache, parameters: parameters)
return generate(
input: input, context: context, iterator: iterator, didGenerate: didGenerate)
}
@@ -702,6 +704,7 @@ public func generate(
/// - input: The input for the language model.
/// - parameters: The configuration options for token generation.
/// - context: The model context, including the model itself and associated tokenizer.
/// - cache: KV cache from previous output
/// - Returns: An `AsyncStream` that emits `Generation` values, including generated tokens (`.token`)
/// and completion information (`.info`).
/// - Throws: An error if the `TokenIterator` initialization fails due to invalid input or model configuration.
@@ -729,10 +732,10 @@ public func generate(
/// }
/// ```
public func generate(
input: LMInput, parameters: GenerateParameters, context: ModelContext
input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil
) throws -> AsyncStream<Generation> {
let iterator = try TokenIterator(
input: input, model: context.model, parameters: parameters)
input: input, model: context.model, cache: cache, parameters: parameters)
return generate(
input: input, context: context, iterator: iterator)
}
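
(Not part of the diff: a hypothetical call site for the new cache parameter on the AsyncStream overload. It assumes a prepared LMInput, a ModelContext, and an existing [KVCache] are available; the .token case carrying the raw token id follows the Generation documentation above.)

    import MLXLMCommon

    func streamWithCache(
        input: LMInput, context: ModelContext, cache: [KVCache]
    ) async throws {
        let parameters = GenerateParameters(temperature: 0.7)
        let stream = try MLXLMCommon.generate(
            input: input, parameters: parameters, context: context, cache: cache)

        for await generation in stream {
            if case .token(let token) = generation {
                print(token)  // raw token id; run through a detokenizer for display
            }
        }
    }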
13 changes: 13 additions & 0 deletions Libraries/MLXLMCommon/KVCache.swift
@@ -12,6 +12,10 @@ public protocol KVCache: Evaluatable {
var offset: Int { get }

func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray)

func isTrimmable() -> Bool

func trim(n: Int) -> Int
}

func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray {
@@ -97,4 +101,13 @@ public class KVCacheSimple: KVCache, Evaluatable {
)
}

public func isTrimmable() -> Bool {
return true
}

public func trim(n: Int) -> Int {
let toTrim = min(self.offset, n)
self.offset -= toTrim
return toTrim
}
}
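
(Illustrative only, not part of the diff: how the trim semantics above behave, written against the KVCache protocol as extended in this PR.)

    // trim(n:) removes at most `offset` positions and reports how many were
    // actually removed, so over-trimming simply empties the cache.
    func trimExample(cache: KVCache) {
        // Suppose the cache currently holds 10 positions, i.e. cache.offset == 10.
        let removed = cache.trim(n: 4)        // removed == 4, offset is now 6
        let removedRest = cache.trim(n: 100)  // removedRest == 6, offset is now 0
        print(removed, removedRest)
    }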