6
6
//
7
7
8
8
import MLX
9
+ import MLXLMCommon
9
10
10
- public actor PromptCache {
11
- public let cache : [ KVCache ]
12
- public var tokens : MLXArray
11
+ /// Stores the KV Cache between calls to ``generate`` and maintains
12
+ /// the token ids reflected in the cache.
13
+ ///
14
+ /// ``PromptCache`` is ``@unchecked Sendable`` which allows it
15
+ /// to be used within the ``ModelContainer`` context.
16
+ ///
17
+ /// TODO: cache isolation
18
+ public class PromptCache : @unchecked Sendable {
19
+ private( set) var cache : [ KVCache ]
20
+ private( set) var tokens : MLXArray
13
21
14
22
public init ( cache: [ KVCache ] ) {
15
23
print ( " [PromptCache.init] " )
@@ -35,7 +43,7 @@ public actor PromptCache {
35
43
/// - Return suffix of prompt not in cache
36
44
/// - If the cache is not trimmable return nil for the caller
37
45
/// to create a new cache.
38
- public func getUncachedSuffix( prompt: MLXArray ) async -> MLXArray ? {
46
+ public func getUncachedSuffix( prompt: MLXArray ) -> MLXArray ? {
39
47
40
48
print ( " [getUncachedSuffix] self.tokens.size = \( self . tokens. size) " )
41
49
@@ -71,30 +79,45 @@ public actor PromptCache {
71
79
return nil
72
80
}
73
81
82
+ /// - Returns: true if all KV caches are trimmable
74
83
public func isTrimmable( ) -> Bool {
75
84
return cache. allSatisfy { $0. isTrimmable ( ) }
76
85
}
77
86
87
+ /// Trims all KV caches.
88
+ /// - Parameters:
89
+ /// - n: Amount to trim.
90
+ /// - Returns: Amount KV Caches were trimmed (may be less than ``n``).
78
91
public func trim( _ n: Int ) -> Int {
79
92
if !self . isTrimmable ( ) {
80
93
return 0
81
94
}
82
95
return cache. map { $0. trim ( n: n) } . max ( ) ?? 0
83
96
}
84
97
98
+ /// Finds the common prefix between the cached prompt and
99
+ /// the new prompt.
100
+ /// - Parameters:
101
+ /// - newPromptTokens: Tokens to compare with cached tokens.
102
+ /// - Returns: Length of the common prefix
85
103
public func commonPrefixLength( newPromptTokens: MLXArray ) -> Int {
86
- return _commonPrefixLength ( self . tokens, newPromptTokens)
104
+ return MLX_Studio . commonPrefixLength ( self . tokens, newPromptTokens)
87
105
}
106
+ }
88
107
89
- // TODO: Add tests
90
- public func _commonPrefixLength( _ array1: MLXArray , _ array2: MLXArray ) -> Int {
91
- print ( " Calculating common prefix: array1[ \( array1. size) ] array2[ \( array2. size) ] " )
92
- let minLength = min ( array1. size, array2. size)
93
- for i in 0 ..< minLength {
94
- if all ( array1 [ i] .!= array2 [ i] ) . item ( Bool . self) {
95
- return i
96
- }
108
+ /// Finds the common prefix between ``MLXArray``s.
109
+ /// - Parameters:
110
+ /// - array1: First array
111
+ /// - array2: Second array
112
+ /// - Returns: Length of the common prefix
113
+ public func commonPrefixLength( _ array1: MLXArray , _ array2: MLXArray ) -> Int {
114
+ // TODO: Add test cases
115
+ print ( " Calculating common prefix: array1[ \( array1. size) ] array2[ \( array2. size) ] " )
116
+ let minLength = min ( array1. size, array2. size)
117
+ for i in 0 ..< minLength {
118
+ if all ( array1 [ i] .!= array2 [ i] ) . item ( Bool . self) {
119
+ return i
97
120
}
98
- return minLength
99
121
}
122
+ return minLength
100
123
}
0 commit comments