@@ -59,6 +59,19 @@ public struct GenerateParameters: Sendable {
59
59
/// Maximum tokens to generate
60
60
public var maxTokens : Int ?
61
61
62
+ /// Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten.
63
+ /// When set, uses ``RotatingKVCache`` instead of ``KVCacheSimple``
64
+ public var maxKVSize : Int ?
65
+
66
+ /// Number of bits to use for KV cache quantization. nil implies no cache quantization.
67
+ public var kvBits : Int ?
68
+
69
+ /// Group size for KV cache quantization (default: 64)
70
+ public var kvGroupSize : Int = 64
71
+
72
+ /// Step to begin using a quantized KV cache when kvBits is non-nil (default: 0)
73
+ public var quantizedKVStart : Int = 0
74
+
62
75
/// sampling temperature
63
76
public var temperature : Float = 0.6
64
77
@@ -73,10 +86,18 @@ public struct GenerateParameters: Sendable {
73
86
74
87
public init (
75
88
maxTokens: Int ? = nil ,
89
+ maxKVSize: Int ? = nil ,
90
+ kvBits: Int ? = nil ,
91
+ kvGroupSize: Int = 64 ,
92
+ quantizedKVStart: Int = 0 ,
76
93
temperature: Float = 0.6 , topP: Float = 1.0 , repetitionPenalty: Float ? = nil ,
77
94
repetitionContextSize: Int = 20
78
95
) {
79
96
self . maxTokens = maxTokens
97
+ self . maxKVSize = maxKVSize
98
+ self . kvBits = kvBits
99
+ self . kvGroupSize = kvGroupSize
100
+ self . quantizedKVStart = quantizedKVStart
80
101
self . temperature = temperature
81
102
self . topP = topP
82
103
self . repetitionPenalty = repetitionPenalty
@@ -257,6 +278,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
257
278
var tokenCount = 0
258
279
let maxTokens : Int ?
259
280
281
+ // Cache quantization parameters
282
+ let kvBits : Int ?
283
+ let kvGroupSize : Int
284
+ let quantizedKVStart : Int
285
+
260
286
/// Initialize a `TokenIterator` with the given tokens. Note: this has been
261
287
/// replaced with ``init(input:model:cache:parameters:)``.
262
288
///
@@ -278,6 +304,10 @@ public struct TokenIterator: Sequence, IteratorProtocol {
278
304
self . sampler = parameters. sampler ( )
279
305
self . maxTokens = parameters. maxTokens
280
306
307
+ self . kvBits = parameters. kvBits
308
+ self . kvGroupSize = parameters. kvGroupSize
309
+ self . quantizedKVStart = parameters. quantizedKVStart
310
+
281
311
try prepare ( input: . init( text: y) , windowSize: parameters. prefillStepSize)
282
312
}
283
313
@@ -305,6 +335,10 @@ public struct TokenIterator: Sequence, IteratorProtocol {
305
335
self . sampler = parameters. sampler ( )
306
336
self . maxTokens = parameters. maxTokens
307
337
338
+ self . kvBits = parameters. kvBits
339
+ self . kvGroupSize = parameters. kvGroupSize
340
+ self . quantizedKVStart = parameters. quantizedKVStart
341
+
308
342
try prepare ( input: input, windowSize: parameters. prefillStepSize)
309
343
}
310
344
@@ -331,6 +365,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
331
365
self . sampler = sampler
332
366
self . maxTokens = maxTokens
333
367
368
+ // No cache quantization for this direct initialization
369
+ self . kvBits = nil
370
+ self . kvGroupSize = 64
371
+ self . quantizedKVStart = 0
372
+
334
373
try prepare ( input: input, windowSize: prefillStepSize)
335
374
}
336
375
@@ -373,6 +412,14 @@ public struct TokenIterator: Sequence, IteratorProtocol {
373
412
previous [ text: . newAxis] , cache: cache. isEmpty ? nil : cache, state: state)
374
413
self . state = result. state
375
414
415
+ // Apply dynamic cache quantization after each step
416
+ maybeQuantizeKVCache (
417
+ cache: & cache,
418
+ kvBits: kvBits,
419
+ kvGroupSize: kvGroupSize,
420
+ quantizedKVStart: quantizedKVStart
421
+ )
422
+
376
423
return convertToToken ( logits: result. logits)
377
424
}
378
425
0 commit comments