Skip to content

Commit a675b2a

Browse files
committed
Added actor PromptCache
1 parent 0400695 commit a675b2a

File tree

3 files changed

+21
-67
lines changed

3 files changed

+21
-67
lines changed

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -526,17 +526,17 @@ public func generate(
526526
/// ``generate(input:context:iterator:didGenerate:)``
527527
///
528528
/// - Parameters:
529-
/// - input: language model input
529+
/// - input: prepared language model input
530530
/// - parameters: parameters controlling the token generation
531531
/// - context: model context (model and tokenizer)
532532
/// - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
533533
/// - Returns: the generated output
534534
public func generate(
535-
input: LMInput, parameters: GenerateParameters, context: ModelContext,
535+
input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
536536
didGenerate: ([Int]) -> GenerateDisposition
537537
) throws -> GenerateResult {
538538
let iterator = try TokenIterator(
539-
input: input, model: context.model, cache: context.kvCache, parameters: parameters)
539+
input: input, model: context.model, cache: cache, parameters: parameters)
540540
return generate(
541541
input: input, context: context, iterator: iterator, didGenerate: didGenerate)
542542
}
@@ -629,11 +629,11 @@ public func generate(
629629
/// - didGenerate: token visitor that can output tokens as they are generated and indicate early stop
630630
/// - Returns: Information about the generation
631631
public func generate(
632-
input: LMInput, parameters: GenerateParameters, context: ModelContext,
632+
input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]? = nil,
633633
didGenerate: (Int) -> GenerateDisposition
634634
) throws -> GenerateCompletionInfo {
635635
let iterator = try TokenIterator(
636-
input: input, model: context.model, cache: context.kvCache, parameters: parameters)
636+
input: input, model: context.model, cache: cache, parameters: parameters)
637637
return generate(
638638
input: input, context: context, iterator: iterator, didGenerate: didGenerate)
639639
}
@@ -729,10 +729,10 @@ public func generate(
729729
/// }
730730
/// ```
731731
public func generate(
732-
input: LMInput, parameters: GenerateParameters, context: ModelContext
732+
input: LMInput, parameters: GenerateParameters, context: ModelContext, cache: [KVCache]?
733733
) throws -> AsyncStream<Generation> {
734734
let iterator = try TokenIterator(
735-
input: input, model: context.model, cache: context.kvCache, parameters: parameters)
735+
input: input, model: context.model, cache: cache, parameters: parameters)
736736
return generate(
737737
input: input, context: context, iterator: iterator)
738738
}

Libraries/MLXLMCommon/KVCache.swift

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ public protocol KVCache: Evaluatable {
1212
var offset: Int { get }
1313

1414
func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray)
15+
16+
func isTrimmable() -> Bool
17+
18+
func trim(n: Int) -> Int
1519
}
1620

1721
func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray {
@@ -96,5 +100,14 @@ public class KVCacheSimple: KVCache, Evaluatable {
96100
self.values![.ellipsis, ..<self.offset, 0...]
97101
)
98102
}
99-
103+
104+
public func isTrimmable() -> Bool {
105+
return true
106+
}
107+
108+
public func trim(n: Int) -> Int {
109+
let toTrim = min(self.offset, n)
110+
self.offset -= toTrim
111+
return toTrim
112+
}
100113
}

Libraries/MLXLMCommon/ModelContainer.swift

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -80,63 +80,4 @@ public actor ModelContainer {
8080
action(&context)
8181
}
8282

83-
/// Clears the Key/Value cache stored within the model context.
84-
public func clearCache() {
85-
context.kvCache = nil
86-
}
87-
88-
/// Prefills the Key/Value cache by running the model's forward pass
89-
/// on the provided tokens.
90-
///
91-
/// This populates the internal cache state, allowing subsequent `generate` calls
92-
/// to start generation immediately after the prefilled tokens without reprocessing them.
93-
///
94-
/// - Parameters:
95-
/// - promptTokens: The token IDs to prefill the cache with.
96-
/// - chunkSize: The number of tokens to process in each model evaluation step. Defaults to 512.
97-
public func prefill(promptTokens: [Int], chunkSize: Int = 512) async {
98-
// Ensure we have tokens to process
99-
guard !promptTokens.isEmpty else {
100-
// If the prompt is empty, ensure the cache is cleared
101-
clearCache()
102-
return
103-
}
104-
105-
// Create a new cache instance
106-
let newCache = context.model.newCache(parameters: nil)
107-
108-
// Convert tokens to MLXArray
109-
let tokensToProcess = MLXArray(promptTokens)
110-
111-
// Process tokens in chunks
112-
var currentOffset = 0
113-
var state: LMOutput.State? = nil // Manage state if the model uses it
114-
115-
while currentOffset < tokensToProcess.size {
116-
let endOffset = min(currentOffset + chunkSize, tokensToProcess.size)
117-
let chunk = tokensToProcess[currentOffset ..< endOffset]
118-
119-
// Create LMInput.Text for the chunk
120-
// Adding a new axis as models typically expect a batch dimension
121-
let inputText = LMInput.Text(tokens: chunk[.newAxis])
122-
123-
// Run the model's forward pass for the chunk
124-
// This implicitly updates the newCache passed to it
125-
let result = context.model(inputText, cache: newCache, state: state)
126-
127-
// Update state if provided by the model
128-
state = result.state
129-
130-
// Move to the next chunk
131-
currentOffset = endOffset
132-
}
133-
134-
// Ensure all computations related to cache population are completed
135-
eval(newCache)
136-
137-
// Store the populated cache in the context
138-
context.kvCache = newCache
139-
}
140-
141-
// TODO: Add trimCache(to offset: Int) method
14283
}

0 commit comments

Comments (0)