Skip to content

Commit f7da396

Browse files
Add missing KV cache functionality (#334)
* Update dependencies * Demonstrate attention routing in Qwen 3
1 parent 66f7a60 commit f7da396

File tree

7 files changed

+1437
-23
lines changed

7 files changed

+1437
-23
lines changed

Libraries/MLXLLM/Models/Qwen3.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,23 @@ private class Attention: Module {
7676
keys = kNorm(keys.reshaped(B, L, args.kvHeads, -1)).transposed(0, 2, 1, 3)
7777
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
7878

79+
// Apply RoPE positioning
7980
if let cache {
8081
queries = rope(queries, offset: cache.offset)
8182
keys = rope(keys, offset: cache.offset)
82-
(keys, values) = cache.update(keys: keys, values: values)
8383
} else {
8484
queries = rope(queries)
8585
keys = rope(keys)
8686
}
8787

88-
let output = MLXFast.scaledDotProductAttention(
89-
queries: queries, keys: keys, values: values, scale: scale, mask: mask
88+
// Use the automatic attention router that handles both quantized and regular caches
89+
let output = attentionWithCacheUpdate(
90+
queries: queries,
91+
keys: keys,
92+
values: values,
93+
cache: cache,
94+
scale: scale,
95+
mask: mask
9096
)
9197
.transposed(0, 2, 1, 3)
9298
.reshaped(B, L, -1)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import Foundation
2+
import MLX
3+
import MLXFast
4+
5+
/// Attention utilities that match Python mlx-lm's interface
6+
///
7+
/// This provides a single function that automatically routes to quantized or regular
8+
/// attention based on cache type, matching Python's `scaled_dot_product_attention`
9+
10+
/// Automatic attention with cache update
11+
///
12+
/// This function matches Python's `scaled_dot_product_attention` in base.py:
13+
/// - Detects if cache is `QuantizedKVCache` using `isinstance` pattern
14+
/// - Routes to `quantizedScaledDotProductAttention` or `MLXFast.scaledDotProductAttention`
15+
/// - Handles cache updating automatically
16+
/// - Transparent to models - they just call this function
17+
///
18+
/// **Usage in models:**
19+
/// ```swift
20+
/// let output = attentionWithCacheUpdate(
21+
/// queries: queries,
22+
/// keys: keys,
23+
/// values: values,
24+
/// cache: cache,
25+
/// scale: scale,
26+
/// mask: mask
27+
/// )
28+
/// ```
29+
///
30+
/// - Parameters:
31+
/// - queries: Query tensor [B, nHeads, L, D]
32+
/// - keys: Raw key tensor to be cached [B, nKVHeads, L, D]
33+
/// - values: Raw value tensor to be cached [B, nKVHeads, L, D]
34+
/// - cache: Cache instance (any type)
35+
/// - scale: Attention scale factor
36+
/// - mask: Attention mask
37+
/// - Returns: Attention output [B, nHeads, L, D]
38+
public func attentionWithCacheUpdate(
39+
queries: MLXArray,
40+
keys: MLXArray,
41+
values: MLXArray,
42+
cache: KVCache?,
43+
scale: Float,
44+
mask: MLXFast.ScaledDotProductAttentionMaskMode = .none
45+
) -> MLXArray {
46+
guard let cache else {
47+
return MLXFast.scaledDotProductAttention(
48+
queries: queries,
49+
keys: keys,
50+
values: values,
51+
scale: scale,
52+
mask: mask
53+
)
54+
}
55+
if let quantizedKVCache = cache as? QuantizedKVCache {
56+
let (quantizedKeys, quantizedValues) = quantizedKVCache.updateQuantized(
57+
keys: keys, values: values)
58+
return quantizedScaledDotProductAttention(
59+
queries: queries,
60+
quantizedKeys: quantizedKeys,
61+
quantizedValues: quantizedValues,
62+
scale: scale,
63+
mask: mask,
64+
groupSize: quantizedKVCache.groupSize,
65+
bits: quantizedKVCache.bits
66+
)
67+
} else {
68+
let (cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
69+
return MLXFast.scaledDotProductAttention(
70+
queries: queries,
71+
keys: cachedKeys,
72+
values: cachedValues,
73+
scale: scale,
74+
mask: mask
75+
)
76+
}
77+
}

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,19 @@ public struct GenerateParameters: Sendable {
5959
/// Maximum tokens to generate
6060
public var maxTokens: Int?
6161

62+
/// Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten.
63+
/// When set, uses ``RotatingKVCache`` instead of ``KVCacheSimple``
64+
public var maxKVSize: Int?
65+
66+
/// Number of bits to use for KV cache quantization. nil implies no cache quantization.
67+
public var kvBits: Int?
68+
69+
/// Group size for KV cache quantization (default: 64)
70+
public var kvGroupSize: Int = 64
71+
72+
/// Step to begin using a quantized KV cache when kvBits is non-nil (default: 0)
73+
public var quantizedKVStart: Int = 0
74+
6275
/// sampling temperature
6376
public var temperature: Float = 0.6
6477

@@ -73,10 +86,18 @@ public struct GenerateParameters: Sendable {
7386

7487
public init(
7588
maxTokens: Int? = nil,
89+
maxKVSize: Int? = nil,
90+
kvBits: Int? = nil,
91+
kvGroupSize: Int = 64,
92+
quantizedKVStart: Int = 0,
7693
temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
7794
repetitionContextSize: Int = 20
7895
) {
7996
self.maxTokens = maxTokens
97+
self.maxKVSize = maxKVSize
98+
self.kvBits = kvBits
99+
self.kvGroupSize = kvGroupSize
100+
self.quantizedKVStart = quantizedKVStart
80101
self.temperature = temperature
81102
self.topP = topP
82103
self.repetitionPenalty = repetitionPenalty
@@ -257,6 +278,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
257278
var tokenCount = 0
258279
let maxTokens: Int?
259280

281+
// Cache quantization parameters
282+
let kvBits: Int?
283+
let kvGroupSize: Int
284+
let quantizedKVStart: Int
285+
260286
/// Initialize a `TokenIterator` with the given tokens. Note: this has been
261287
/// replaced with ``init(input:model:cache:parameters:)``.
262288
///
@@ -278,6 +304,10 @@ public struct TokenIterator: Sequence, IteratorProtocol {
278304
self.sampler = parameters.sampler()
279305
self.maxTokens = parameters.maxTokens
280306

307+
self.kvBits = parameters.kvBits
308+
self.kvGroupSize = parameters.kvGroupSize
309+
self.quantizedKVStart = parameters.quantizedKVStart
310+
281311
try prepare(input: .init(text: y), windowSize: parameters.prefillStepSize)
282312
}
283313

@@ -305,6 +335,10 @@ public struct TokenIterator: Sequence, IteratorProtocol {
305335
self.sampler = parameters.sampler()
306336
self.maxTokens = parameters.maxTokens
307337

338+
self.kvBits = parameters.kvBits
339+
self.kvGroupSize = parameters.kvGroupSize
340+
self.quantizedKVStart = parameters.quantizedKVStart
341+
308342
try prepare(input: input, windowSize: parameters.prefillStepSize)
309343
}
310344

@@ -331,6 +365,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
331365
self.sampler = sampler
332366
self.maxTokens = maxTokens
333367

368+
// No cache quantization for this direct initialization
369+
self.kvBits = nil
370+
self.kvGroupSize = 64
371+
self.quantizedKVStart = 0
372+
334373
try prepare(input: input, windowSize: prefillStepSize)
335374
}
336375

@@ -373,6 +412,14 @@ public struct TokenIterator: Sequence, IteratorProtocol {
373412
previous[text: .newAxis], cache: cache.isEmpty ? nil : cache, state: state)
374413
self.state = result.state
375414

415+
// Apply dynamic cache quantization after each step
416+
maybeQuantizeKVCache(
417+
cache: &cache,
418+
kvBits: kvBits,
419+
kvGroupSize: kvGroupSize,
420+
quantizedKVStart: quantizedKVStart
421+
)
422+
376423
return convertToToken(logits: result.logits)
377424
}
378425

0 commit comments

Comments
 (0)