@@ -85,20 +85,58 @@ public actor ModelContainer {
         context.kvCache = nil
     }

-    /// Prefills the Key/Value cache with the given prompt tokens.
+    /// Prefills the Key/Value cache by running the model's forward pass
+    /// on the provided tokens.
     ///
-    /// - Parameter promptTokens: The token IDs to prefill the cache with.
-    /// - Note: This requires specific model support to run the forward pass
-    ///   without full generation and extract the cache state.
-    ///   Implementation is pending further model integration.
-    public func prefill(promptTokens: [Int]) async throws {
-        // TODO: Implement prefill logic.
-        // This will involve:
-        // 1. Ensuring the model supports cache extraction.
-        // 2. Running a partial forward pass with promptTokens.
-        // 3. Storing the resulting KVCache in context.kvCache.
-        print("Prefill functionality not yet implemented.")
-        // For now, just clear the cache to avoid using a potentially stale one.
-        clearCache()
+    /// This populates the internal cache state, allowing subsequent `generate` calls
+    /// to start generation immediately after the prefilled tokens without reprocessing them.
+    ///
+    /// - Parameters:
+    ///   - promptTokens: The token IDs to prefill the cache with.
+    ///   - chunkSize: The number of tokens to process in each model evaluation step. Defaults to 512.
+    public func prefill(promptTokens: [Int], chunkSize: Int = 512) {
+        // Ensure we have tokens to process
+        guard !promptTokens.isEmpty else {
+            // If the prompt is empty, ensure the cache is cleared
+            clearCache()
+            return
+        }
+
+        // Create a new cache instance
+        let newCache = context.model.newCache(parameters: nil)
+
+        // Convert tokens to MLXArray
+        let tokensToProcess = MLXArray(promptTokens)
+
+        // Process tokens in chunks
+        var currentOffset = 0
+        var state: LMOutput.State? = nil // Manage state if the model uses it
+
+        while currentOffset < tokensToProcess.size {
+            let endOffset = min(currentOffset + chunkSize, tokensToProcess.size)
+            let chunk = tokensToProcess[currentOffset ..< endOffset]
+
+            // Create LMInput.Text for the chunk
+            // Adding a new axis as models typically expect a batch dimension
+            let inputText = LMInput.Text(tokens: chunk[.newAxis])
+
+            // Run the model's forward pass for the chunk
+            // This implicitly updates the newCache passed to it
+            let result = context.model(inputText, cache: newCache, state: state)
+
+            // Update state if provided by the model
+            state = result.state
+
+            // Move to the next chunk
+            currentOffset = endOffset
+        }
+
+        // Ensure all computations related to cache population are completed
+        eval(newCache)
+
+        // Store the populated cache in the context
+        context.kvCache = newCache
     }
+
+    // TODO: Add trimCache(to offset: Int) method
 }
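
For reference, a hypothetical call site might look like the sketch below. The `container` and `tokenizer` names are placeholders and are not part of this diff; only `prefill(promptTokens:chunkSize:)` and `context.kvCache` come from the change above. Because `ModelContainer` is an actor, the call is awaited from outside the actor even though the method itself is synchronous.

    // Hypothetical usage sketch; `container` and `tokenizer` are assumed to exist.
    let tokens = tokenizer.encode(text: "Summarize the following document: ...")

    // Populate the KV cache once for the shared prompt prefix.
    await container.prefill(promptTokens: tokens, chunkSize: 256)

    // Subsequent generate calls can start from the cached state instead of
    // re-running the forward pass over the prompt tokens.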