
Commit 566c8ec

Update Gemma and Gemma 2 to more closely follow Python implementations (#156)
* Fix Gemma 2
* Update Gemma to more closely follow Python implementation
1 parent 2dbe65d commit 566c8ec

2 files changed: 69 additions & 68 deletions

Libraries/LLM/Models/Gemma.swift

Lines changed: 40 additions & 31 deletions
@@ -5,9 +5,9 @@ import MLX
 import MLXFast
 import MLXNN
 
-// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py
+// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py
 
-// specialized norm for gemma
+// Specialized norm for Gemma
 private class RMSNorm: Module, UnaryLayer {
     let weight: MLXArray
     let eps: Float
@@ -24,8 +24,10 @@ private class RMSNorm: Module, UnaryLayer {
 }
 
 private class Attention: Module {
-
     let args: GemmaConfiguration
+    let nHeads: Int
+    let nKVHeads: Int
+    let headDim: Int
     let scale: Float
 
     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -39,16 +41,15 @@ private class Attention: Module {
         self.args = args
 
         let dim = args.hiddenSize
-        let heads = args.attentionHeads
-        let kvHeads = args.kvHeads
-
-        let headDim = args.headDimensions
+        self.nHeads = args.attentionHeads
+        self.nKVHeads = args.kvHeads
+        self.headDim = args.headDimensions
         self.scale = pow(Float(headDim), -0.5)
 
-        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
-        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
+        self._wq.wrappedValue = Linear(dim, nHeads * headDim, bias: false)
+        self._wk.wrappedValue = Linear(dim, nKVHeads * headDim, bias: false)
+        self._wv.wrappedValue = Linear(dim, nKVHeads * headDim, bias: false)
+        self._wo.wrappedValue = Linear(nHeads * headDim, dim, bias: false)
 
         self.rope = RoPE(
             dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
@@ -63,10 +64,10 @@ private class Attention: Module {
         var keys = wk(x)
         var values = wv(x)
 
-        // prepare the queries, keys and values for the attention computation
-        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
-        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
-        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+        // Prepare the queries, keys and values for the attention computation
+        queries = queries.reshaped(B, L, nHeads, -1).transposed(0, 2, 1, 3)
+        keys = keys.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
+        values = values.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
 
         if let cache {
             queries = rope(queries, offset: cache.offset)
@@ -88,7 +89,6 @@ private class Attention: Module {
 }
 
 private class MLP: Module, UnaryLayer {
-
     @ModuleInfo(key: "gate_proj") var gate: Linear
     @ModuleInfo(key: "down_proj") var down: Linear
     @ModuleInfo(key: "up_proj") var up: Linear
@@ -105,6 +105,8 @@ private class MLP: Module, UnaryLayer {
 }
 
 private class TransformerBlock: Module {
+    let numAttentionHeads: Int
+    let hiddenSize: Int
 
     @ModuleInfo(key: "self_attn") var attention: Attention
     let mlp: MLP
@@ -113,6 +115,9 @@ private class TransformerBlock: Module {
     @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
 
     public init(_ args: GemmaConfiguration) {
+        self.numAttentionHeads = args.attentionHeads
+        self.hiddenSize = args.hiddenSize
+
         self._attention.wrappedValue = Attention(args)
         self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
         self._inputLayerNorm.wrappedValue = RMSNorm(
@@ -127,28 +132,29 @@ private class TransformerBlock: Module {
         var r = attention(inputLayerNorm(x), mask: mask, cache: cache)
         let h = x + r
         r = mlp(postAttentionLayerNorm(h))
-        let out = h + r
-        return out
+        return h + r
     }
 }
 
 public class GemmaModelInner: Module {
+    let args: GemmaConfiguration
+    let vocabularySize: Int
+    let numHiddenLayers: Int
 
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
-
     fileprivate let layers: [TransformerBlock]
     fileprivate let norm: RMSNorm
 
-    let hiddenScale: Float
-
     public init(_ args: GemmaConfiguration) {
         precondition(args.vocabularySize > 0)
 
+        self.args = args
+        self.vocabularySize = args.vocabularySize
+        self.numHiddenLayers = args.hiddenLayers
+
         self._embedTokens.wrappedValue = Embedding(
             embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
 
-        self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
-
         self.layers = (0 ..< args.hiddenLayers)
             .map { _ in
                 TransformerBlock(args)
@@ -158,7 +164,7 @@ public class GemmaModelInner: Module {
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
         var h = embedTokens(inputs)
-        h = h * hiddenScale
+        h = h * pow(Float(args.hiddenSize), 0.5)
 
         let mask: MLXArray? = createAttentionMask(h: h, cache: cache)
 
@@ -171,29 +177,29 @@ public class GemmaModelInner: Module {
 }
 
 public class GemmaModel: Module, LLMModel, KVCacheDimensionProvider {
-
     public let vocabularySize: Int
     public let kvHeads: [Int]
     public let headDim: IntOrPair
 
+    let modelType: String
     let model: GemmaModelInner
 
     public init(_ args: GemmaConfiguration) {
+        self.modelType = args.modelType
         self.vocabularySize = args.vocabularySize
-        self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
+        self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
         self.headDim = .init(args.headDimensions)
         self.model = GemmaModelInner(args)
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
-        var out = model(inputs, cache: cache)
-        out = matmul(out, model.embedTokens.weight.T)
-        return out
+        let out = model(inputs, cache: cache)
+        return model.embedTokens.asLinear(out)
     }
 }
 
 public struct GemmaConfiguration: Codable, Sendable {
-
+    var modelType: String
     var hiddenSize: Int
     var hiddenLayers: Int
     var intermediateSize: Int
@@ -206,6 +212,7 @@ public struct GemmaConfiguration: Codable, Sendable {
     var ropeTraditional: Bool = false
 
     enum CodingKeys: String, CodingKey {
+        case modelType = "model_type"
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
@@ -219,10 +226,12 @@ public struct GemmaConfiguration: Codable, Sendable {
     }
 
     public init(from decoder: Decoder) throws {
-        // custom implementation to handle optional keys with required values
+        // Custom implementation to handle optional keys with required values
         let container: KeyedDecodingContainer<CodingKeys> = try decoder.container(
             keyedBy: CodingKeys.self)
 
+        self.modelType = try container.decode(
+            String.self, forKey: CodingKeys.modelType)
         self.hiddenSize = try container.decode(
             Int.self, forKey: CodingKeys.hiddenSize)
         self.hiddenLayers = try container.decode(
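
Note on the output-projection change in this diff: GemmaModel previously projected hidden states onto the vocabulary by multiplying with the transposed embedding matrix, and the commit switches to Embedding.asLinear, which reuses the same tied weight. The hidden-state scaling also moved from the stored hiddenScale property to an inline pow(Float(args.hiddenSize), 0.5); the value is unchanged. A minimal sketch of the equivalence, where tiedLogits is a hypothetical helper (not part of the commit) and only calls that appear in the diff are used:

// Sketch only, assuming an Embedding and hidden states of matching sizes.
import MLX
import MLXNN

func tiedLogits(_ hidden: MLXArray, _ embedTokens: Embedding) -> (old: MLXArray, new: MLXArray) {
    let old = matmul(hidden, embedTokens.weight.T)  // pre-commit form: explicit transpose + matmul
    let new = embedTokens.asLinear(hidden)          // post-commit form: embedding reused as a linear head
    return (old, new)
}

Both forms produce logits over the vocabulary from the shared embedding weight; the asLinear form matches the Python mlx_lm implementation the commit follows.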

Libraries/LLM/Models/Gemma2.swift

Lines changed: 29 additions & 37 deletions
@@ -7,7 +7,7 @@ import MLXNN
 
 // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
 
-// specialized norm for gemma
+// Specialized norm for gemma
 private class RMSNorm: Module, UnaryLayer {
     let weight: MLXArray
     let eps: Float
@@ -24,11 +24,13 @@ private class RMSNorm: Module, UnaryLayer {
 }
 
 private class Attention: Module {
-
     let args: Gemma2Configuration
     let scale: Float
     let logitSoftCap: Float
     let headDim: Int
+    let nHeads: Int
+    let nKVHeads: Int
+    let repeats: Int
 
     @ModuleInfo(key: "q_proj") var wq: Linear
     @ModuleInfo(key: "k_proj") var wk: Linear
@@ -41,19 +43,18 @@ private class Attention: Module {
         self.args = args
 
         let dim = args.hiddenSize
-        let heads = args.attentionHeads
-        let kvHeads = args.kvHeads
-
-        let headDim = args.headDimensions
-        self.headDim = headDim
-        self.scale = pow(Float(args.queryPreAttnScalar), -0.5)
-        self.logitSoftCap = args.attnLogitSoftcapping
+        self.nHeads = args.attentionHeads
+        self.nKVHeads = args.kvHeads
+        self.repeats = args.attentionHeads / args.kvHeads
+        self.headDim = args.headDimensions
 
-        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
-        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
+        self.scale = 1.0 / pow(Float(args.queryPreAttnScalar), 0.5)
 
+        self._wq.wrappedValue = Linear(dim, nHeads * headDim, bias: false)
+        self._wk.wrappedValue = Linear(dim, nKVHeads * headDim, bias: false)
+        self._wv.wrappedValue = Linear(dim, nKVHeads * headDim, bias: false)
+        self._wo.wrappedValue = Linear(nHeads * headDim, dim, bias: false)
+        self.logitSoftCap = args.attnLogitSoftcapping
         self.rope = RoPE(
             dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
     }
@@ -62,15 +63,12 @@ private class Attention: Module {
         _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
     ) -> MLXArray {
         let (B, L) = (x.dim(0), x.dim(1))
-
         var queries = wq(x)
         var keys = wk(x)
         var values = wv(x)
-
-        // prepare the queries, keys and values for the attention computation
-        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
-        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
-        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+        queries = queries.reshaped(B, L, nHeads, -1).transposed(0, 2, 1, 3)
+        keys = keys.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
+        values = values.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
 
         if let cache {
             queries = rope(queries, offset: cache.offset)
@@ -81,33 +79,31 @@ private class Attention: Module {
             keys = rope(keys)
         }
 
-        let repeats = self.args.attentionHeads / self.args.kvHeads
+        queries = queries * self.scale
+
         if repeats > 1 {
-            queries = queries.reshaped(
-                [B, self.args.kvHeads, repeats, L, self.headDim]
-            )
+            queries = queries.reshaped([B, nKVHeads, repeats, L, headDim])
             keys = expandedDimensions(keys, axes: [2])
             values = expandedDimensions(values, axes: [2])
         }
 
         var scores = matmul(queries, keys.swappedAxes(-1, -2))
-        scores = tanh(scores / self.logitSoftCap) * self.logitSoftCap
+        scores = tanh(scores / logitSoftCap) * logitSoftCap
 
-        if mask != nil {
-            scores = scores + mask!
+        if let mask {
+            scores = scores + mask
         }
         scores = softmax(scores, axis: -1, precise: true)
         var output = matmul(scores, values)
         if repeats > 1 {
-            output = output.reshaped([B, self.args.attentionHeads, L, self.headDim])
+            output = output.reshaped([B, nHeads, L, headDim])
         }
         output = output.transposed(0, 2, 1, 3).reshaped(B, L, -1)
         return wo(output)
     }
 }
 
 private class MLP: Module, UnaryLayer {
-
     @ModuleInfo(key: "gate_proj") var gate: Linear
     @ModuleInfo(key: "down_proj") var down: Linear
     @ModuleInfo(key: "up_proj") var up: Linear
@@ -125,7 +121,6 @@ private class MLP: Module, UnaryLayer {
 
 // Minimal changes from Gemma TransformerBlock
 private class TransformerBlock: Module {
-
     @ModuleInfo(key: "self_attn") var attention: Attention
     let mlp: MLP
 
@@ -160,7 +155,6 @@ private class TransformerBlock: Module {
 
 // Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
 public class ModelInner: Module {
-
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
 
     fileprivate let layers: [TransformerBlock]
@@ -199,7 +193,6 @@ public class ModelInner: Module {
 
 // Uses Gemma2ModelInner, otherwise same as GemmaModel
 public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider {
-
     public let vocabularySize: Int
     public let kvHeads: [Int]
     public let headDim: IntOrPair
@@ -209,7 +202,7 @@ public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider {
 
     public init(_ args: Gemma2Configuration) {
         self.vocabularySize = args.vocabularySize
-        self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
+        self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
         self.headDim = .init(args.headDimensions)
         self.model = ModelInner(args)
         self.logitSoftCap = args.finalLogitSoftcapping
@@ -218,13 +211,12 @@ public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider {
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
         var out = model(inputs, cache: cache)
         out = model.embedTokens.asLinear(out)
-        out = tanh(out / self.logitSoftCap) * self.logitSoftCap
+        out = tanh(out / logitSoftCap) * logitSoftCap
         return out
     }
 }
 
 public struct Gemma2Configuration: Codable {
-
     var hiddenSize: Int
     var hiddenLayers: Int
     var intermediateSize: Int
@@ -237,7 +229,7 @@ public struct Gemma2Configuration: Codable {
     var ropeTraditional: Bool = false
     var attnLogitSoftcapping: Float = 50.0
     var finalLogitSoftcapping: Float = 30.0
-    var queryPreAttnScalar: Int = 256
+    var queryPreAttnScalar: Float = 144.0
 
     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -256,7 +248,7 @@ public struct Gemma2Configuration: Codable {
     }
 
     public init(from decoder: Decoder) throws {
-        // custom implementation to handle optional keys with required values
+        // Custom implementation to handle optional keys with required values
        let container: KeyedDecodingContainer<CodingKeys> = try decoder.container(
            keyedBy: CodingKeys.self)
 
@@ -286,7 +278,7 @@ public struct Gemma2Configuration: Codable {
        self.finalLogitSoftcapping = try container.decode(
            Float.self, forKey: CodingKeys.finalLogitSoftcapping)
        self.queryPreAttnScalar = try container.decode(
-            Int.self, forKey: CodingKeys.queryPreAttnScalar)
+            Float.self, forKey: CodingKeys.queryPreAttnScalar)
     }
 }