Commit 2fcd726

Update models for KV cache routing (#338)
1 parent 7482e98 · commit 2fcd726

21 files changed: +190 lines, −93 lines
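
The pattern is the same in every model touched here: the explicit (keys, values) = cache.update(keys:values:) call inside the RoPE branch is removed, and the direct MLXFast.scaledDotProductAttention call is replaced by an attentionWithCacheUpdate helper that receives the cache and performs the update (and any cache-specific routing) itself. As a rough illustration only, not the repository's actual implementation, a minimal sketch of such a helper is shown below; it assumes the KVCache protocol from MLXLMCommon and a plain MLXArray? mask, whereas the real helper may use a different mask type and dispatch to specialized (e.g. quantized) cache kernels.

import MLX
import MLXFast
import MLXLMCommon  // assumed to provide the KVCache protocol used by these models

// Hypothetical sketch of a cache-routing attention helper: it only mirrors the
// call shape visible in the diffs below and is not the commit's actual code.
func attentionWithCacheUpdateSketch(
    queries: MLXArray,
    keys: MLXArray,
    values: MLXArray,
    cache: KVCache?,
    scale: Float,
    mask: MLXArray? = nil
) -> MLXArray {
    if let cache {
        // Append the new keys/values to the cache and attend over the full
        // cached sequence, exactly as the per-model code did before.
        let (cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
        return MLXFast.scaledDotProductAttention(
            queries: queries, keys: cachedKeys, values: cachedValues,
            scale: scale, mask: mask)
    } else {
        // No cache: attend over the freshly computed keys/values.
        return MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values,
            scale: scale, mask: mask)
    }
}

Centralizing the update in one helper means each attention layer no longer has to rebind keys and values, and a cache implementation can change how the update and attention interact without touching every model.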

Libraries/MLXLLM/Models/Cohere.swift
Lines changed: 7 additions & 3 deletions

@@ -53,14 +53,18 @@ private class Attention: Module {
         if let cache {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
-            (keys, values) = cache.update(keys: keys, values: values)
         } else {
             queries = rope(queries)
             keys = rope(keys)
         }
 
-        let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        let output = attentionWithCacheUpdate(
+            queries: queries,
+            keys: keys,
+            values: values,
+            cache: cache,
+            scale: scale,
+            mask: mask
         )
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)

Libraries/MLXLLM/Models/GLM4.swift
Lines changed: 6 additions & 4 deletions

@@ -56,15 +56,17 @@ private class Attention: Module {
         if let cache {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
-            (keys, values) = cache.update(
-                keys: keys, values: values)
         } else {
             queries = rope(queries)
             keys = rope(keys)
         }
 
-        let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale,
+        let output = attentionWithCacheUpdate(
+            queries: queries,
+            keys: keys,
+            values: values,
+            cache: cache,
+            scale: scale,
             mask: mask
         )
         .transposed(0, 2, 1, 3)

Libraries/MLXLLM/Models/Gemma.swift
Lines changed: 7 additions & 3 deletions

@@ -72,14 +72,18 @@ private class Attention: Module {
         if let cache {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
-            (keys, values) = cache.update(keys: keys, values: values)
         } else {
             queries = rope(queries)
             keys = rope(keys)
         }
 
-        let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        let output = attentionWithCacheUpdate(
+            queries: queries,
+            keys: keys,
+            values: values,
+            cache: cache,
+            scale: scale,
+            mask: mask
         )
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)

Libraries/MLXLLM/Models/Gemma3Text.swift
Lines changed: 4 additions & 3 deletions

@@ -163,7 +163,6 @@ private class Attention: Module {
         if let cache {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
-            (keys, values) = cache.update(keys: keys, values: values)
         } else {
             queries = rope(queries)
             keys = rope(keys)
@@ -179,14 +178,16 @@ private class Attention: Module {
             }
         }
 
-        var output = MLXFast.scaledDotProductAttention(
+        let output = attentionWithCacheUpdate(
             queries: queries,
             keys: keys,
             values: values,
+            cache: cache,
             scale: scale,
             mask: finalMask
         )
-        output = output.transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
         return outputProj(output)
     }
 }

Libraries/MLXLLM/Models/Granite.swift
Lines changed: 7 additions & 3 deletions

@@ -71,14 +71,18 @@ private class Attention: Module {
         if let cache {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
-            (keys, values) = cache.update(keys: keys, values: values)
         } else {
             queries = rope(queries)
             keys = rope(keys)
         }
 
-        let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: self.scale, mask: mask
+        let output = attentionWithCacheUpdate(
+            queries: queries,
+            keys: keys,
+            values: values,
+            cache: cache,
+            scale: scale,
+            mask: mask
         )
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)

Libraries/MLXLLM/Models/Internlm2.swift
Lines changed: 7 additions & 3 deletions

@@ -109,14 +109,18 @@ private class Attention: Module {
         if let cache {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
-            (keys, values) = cache.update(keys: keys, values: values)
         } else {
             queries = rope(queries)
             keys = rope(keys)
         }
 
-        let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        let output = attentionWithCacheUpdate(
+            queries: queries,
+            keys: keys,
+            values: values,
+            cache: cache,
+            scale: scale,
+            mask: mask
         )
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)

Libraries/MLXLLM/Models/Llama.swift
Lines changed: 7 additions & 3 deletions

@@ -192,14 +192,18 @@ private class Attention: Module {
         if let cache {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
-            (keys, values) = cache.update(keys: keys, values: values)
         } else {
             queries = rope(queries)
             keys = rope(keys)
         }
 
-        let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        let output = attentionWithCacheUpdate(
+            queries: queries,
+            keys: keys,
+            values: values,
+            cache: cache,
+            scale: scale,
+            mask: mask
         )
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)

Libraries/MLXLLM/Models/MiMo.swift
Lines changed: 7 additions & 3 deletions

@@ -71,14 +71,18 @@ private class Attention: Module {
         if let cache {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
-            (keys, values) = cache.update(keys: keys, values: values)
         } else {
             queries = rope(queries)
             keys = rope(keys)
         }
 
-        let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        let output = attentionWithCacheUpdate(
+            queries: queries,
+            keys: keys,
+            values: values,
+            cache: cache,
+            scale: scale,
+            mask: mask
         )
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)

Libraries/MLXLLM/Models/OpenELM.swift
Lines changed: 10 additions & 4 deletions

@@ -79,15 +79,21 @@ private class MultiHeadCausalAttention: Module {
         if let cache {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
-            (keys, values) = cache.update(keys: keys, values: values)
         } else {
             queries = rope(queries)
             keys = rope(keys)
         }
 
-        let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
-        ).transposed(0, 2, 1, 3).reshaped(B, L, heads * headDim)
+        let output = attentionWithCacheUpdate(
+            queries: queries,
+            keys: keys,
+            values: values,
+            cache: cache,
+            scale: scale,
+            mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, heads * headDim)
 
         return outProj(output)
     }

Libraries/MLXLLM/Models/Phi.swift
Lines changed: 7 additions & 3 deletions

@@ -60,16 +60,20 @@ private class PhiAttention: Module {
         if let cache {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
-            (keys, values) = cache.update(keys: keys, values: values)
         } else {
             queries = rope(queries)
             keys = rope(keys)
         }
 
         // Finally perform the attention computation
         let scale = sqrt(1 / Float(queries.dim(-1)))
-        let output = MLXFast.scaledDotProductAttention(
-            queries: queries.asType(.float32), keys: keys, values: values, scale: scale, mask: mask
+        let output = attentionWithCacheUpdate(
+            queries: queries.asType(.float32),
+            keys: keys,
+            values: values,
+            cache: cache,
+            scale: scale,
+            mask: mask
         )
         .asType(values.dtype)
         .transposed(0, 2, 1, 3)
