Commit 1275acc

feat: Implement prefill method in ModelContainer
1 parent 9b0aa5e commit 1275acc

File tree

1 file changed: +52 -14 lines changed

Libraries/MLXLMCommon/ModelContainer.swift

Lines changed: 52 additions & 14 deletions
@@ -85,20 +85,58 @@ public actor ModelContainer {
         context.kvCache = nil
     }
 
-    /// Prefills the Key/Value cache with the given prompt tokens.
+    /// Prefills the Key/Value cache by running the model's forward pass
+    /// on the provided tokens.
     ///
-    /// - Parameter promptTokens: The token IDs to prefill the cache with.
-    /// - Note: This requires specific model support to run the forward pass
-    ///   without full generation and extract the cache state.
-    ///   Implementation is pending further model integration.
-    public func prefill(promptTokens: [Int]) async throws {
-        // TODO: Implement prefill logic.
-        // This will involve:
-        // 1. Ensuring the model supports cache extraction.
-        // 2. Running a partial forward pass with promptTokens.
-        // 3. Storing the resulting KVCache in context.kvCache.
-        print("Prefill functionality not yet implemented.")
-        // For now, just clear the cache to avoid using a potentially stale one.
-        clearCache()
+    /// This populates the internal cache state, allowing subsequent `generate` calls
+    /// to start generation immediately after the prefilled tokens without reprocessing them.
+    ///
+    /// - Parameters:
+    ///   - promptTokens: The token IDs to prefill the cache with.
+    ///   - chunkSize: The number of tokens to process in each model evaluation step. Defaults to 512.
+    public func prefill(promptTokens: [Int], chunkSize: Int = 512) {
+        // Ensure we have tokens to process
+        guard !promptTokens.isEmpty else {
+            // If the prompt is empty, ensure the cache is cleared
+            clearCache()
+            return
+        }
+
+        // Create a new cache instance
+        let newCache = context.model.newCache(parameters: nil)
+
+        // Convert tokens to MLXArray
+        var tokensToProcess = MLXArray(promptTokens)
+
+        // Process tokens in chunks
+        var currentOffset = 0
+        var state: LMOutput.State? = nil // Manage state if the model uses it
+
+        while currentOffset < tokensToProcess.size {
+            let endOffset = min(currentOffset + chunkSize, tokensToProcess.size)
+            let chunk = tokensToProcess[currentOffset ..< endOffset]
+
+            // Create LMInput.Text for the chunk
+            // Adding a new axis as models typically expect a batch dimension
+            let inputText = LMInput.Text(tokens: chunk[.newAxis])
+
+            // Run the model's forward pass for the chunk
+            // This implicitly updates the newCache passed to it
+            let result = context.model(inputText, cache: newCache, state: state)
+
+            // Update state if provided by the model
+            state = result.state
+
+            // Move to the next chunk
+            currentOffset = endOffset
+        }
+
+        // Ensure all computations related to cache population are completed
+        eval(newCache)
+
+        // Store the populated cache in the context
+        context.kvCache = newCache
     }
+
+    // TODO: Add trimCache(to offset: Int) method
 }
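
For reference, a minimal call-site sketch (not part of this commit): it assumes the ModelContainer has already been loaded and the prompt has already been tokenized into IDs elsewhere; warmUpCache is a hypothetical helper name, and chunkSize: 256 is an arbitrary example value.

import MLXLMCommon

// Hypothetical helper (illustrative only): prefill the KV cache once so that
// later generate calls can reuse the prompt state instead of reprocessing it.
func warmUpCache(container: ModelContainer, promptTokens: [Int]) async {
    // prefill is actor-isolated, hence `await`; chunkSize defaults to 512 if omitted.
    await container.prefill(promptTokens: promptTokens, chunkSize: 256)
}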
