@@ -7,7 +7,7 @@ import MLXNN
 
 // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
 
-// specialized norm for gemma
+// Specialized norm for gemma
 private class RMSNorm: Module, UnaryLayer {
     let weight: MLXArray
     let eps: Float
@@ -24,11 +24,13 @@ private class RMSNorm: Module, UnaryLayer {
 }
 
 private class Attention: Module {
-
     let args: Gemma2Configuration
     let scale: Float
     let logitSoftCap: Float
     let headDim: Int
+    let nHeads: Int
+    let nKVHeads: Int
+    let repeats: Int
 
     @ModuleInfo(key: "q_proj") var wq: Linear
     @ModuleInfo(key: "k_proj") var wk: Linear
@@ -41,19 +43,18 @@ private class Attention: Module {
         self.args = args
 
         let dim = args.hiddenSize
-        let heads = args.attentionHeads
-        let kvHeads = args.kvHeads
-
-        let headDim = args.headDimensions
-        self.headDim = headDim
-        self.scale = pow(Float(args.queryPreAttnScalar), -0.5)
-        self.logitSoftCap = args.attnLogitSoftcapping
+        self.nHeads = args.attentionHeads
+        self.nKVHeads = args.kvHeads
+        self.repeats = args.attentionHeads / args.kvHeads
+        self.headDim = args.headDimensions
 
-        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
-        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
+        self.scale = 1.0 / pow(Float(args.queryPreAttnScalar), 0.5)
 
+        self._wq.wrappedValue = Linear(dim, nHeads * headDim, bias: false)
+        self._wk.wrappedValue = Linear(dim, nKVHeads * headDim, bias: false)
+        self._wv.wrappedValue = Linear(dim, nKVHeads * headDim, bias: false)
+        self._wo.wrappedValue = Linear(nHeads * headDim, dim, bias: false)
+        self.logitSoftCap = args.attnLogitSoftcapping
         self.rope = RoPE(
             dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
     }
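Note on the scale change in the hunk above: `pow(x, -0.5)` and `1.0 / pow(x, 0.5)` both compute 1/sqrt(x), so the refactor leaves the attention scale unchanged; with the new default queryPreAttnScalar = 144.0 it is 1/12 ≈ 0.0833. A minimal standalone check (illustrative only, not part of the diff):

    import Foundation

    // Both formulations of the query scale agree numerically.
    let queryPreAttnScalar: Float = 144.0
    let scaleOld = pow(queryPreAttnScalar, -0.5)        // old form: x^(-1/2)
    let scaleNew = 1.0 / pow(queryPreAttnScalar, 0.5)   // new form: 1 / sqrt(x)
    assert(abs(scaleOld - scaleNew) < 1e-6)             // both ≈ 0.0833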
@@ -62,15 +63,12 @@ private class Attention: Module {
         _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
     ) -> MLXArray {
         let (B, L) = (x.dim(0), x.dim(1))
-
         var queries = wq(x)
         var keys = wk(x)
         var values = wv(x)
-
-        // prepare the queries, keys and values for the attention computation
-        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
-        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
-        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+        queries = queries.reshaped(B, L, nHeads, -1).transposed(0, 2, 1, 3)
+        keys = keys.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
+        values = values.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
 
         if let cache {
             queries = rope(queries, offset: cache.offset)
@@ -81,33 +79,31 @@ private class Attention: Module {
             keys = rope(keys)
         }
 
-        let repeats = self.args.attentionHeads / self.args.kvHeads
+        queries = queries * self.scale
+
         if repeats > 1 {
-            queries = queries.reshaped(
-                [B, self.args.kvHeads, repeats, L, self.headDim]
-            )
+            queries = queries.reshaped([B, nKVHeads, repeats, L, headDim])
             keys = expandedDimensions(keys, axes: [2])
             values = expandedDimensions(values, axes: [2])
         }
 
         var scores = matmul(queries, keys.swappedAxes(-1, -2))
-        scores = tanh(scores / self.logitSoftCap) * self.logitSoftCap
+        scores = tanh(scores / logitSoftCap) * logitSoftCap
 
-        if mask != nil {
-            scores = scores + mask!
+        if let mask {
+            scores = scores + mask
         }
         scores = softmax(scores, axis: -1, precise: true)
         var output = matmul(scores, values)
         if repeats > 1 {
-            output = output.reshaped([B, self.args.attentionHeads, L, self.headDim])
+            output = output.reshaped([B, nHeads, L, headDim])
         }
         output = output.transposed(0, 2, 1, 3).reshaped(B, L, -1)
         return wo(output)
     }
 }
 
 private class MLP: Module, UnaryLayer {
-
     @ModuleInfo(key: "gate_proj") var gate: Linear
     @ModuleInfo(key: "down_proj") var down: Linear
     @ModuleInfo(key: "up_proj") var up: Linear
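Two things worth noting in the attention hunk above. First, when repeats > 1 the queries are reshaped to [B, nKVHeads, repeats, L, headDim] while keys and values only gain a singleton axis at position 2, so the matmul broadcasts over the repeats dimension rather than materializing repeated K/V tensors. Second, the logit soft-capping line can be read as a small helper; a hedged sketch (the softCap name is mine, not something this diff introduces):

    import MLX

    // Soft-capping squashes attention logits into (-cap, cap) while staying
    // roughly linear near zero: cap * tanh(x / cap).
    func softCap(_ x: MLXArray, cap: Float) -> MLXArray {
        tanh(x / cap) * cap
    }

    // e.g. scores = softCap(matmul(queries, keys.swappedAxes(-1, -2)), cap: 50.0)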
@@ -125,7 +121,6 @@ private class MLP: Module, UnaryLayer {
 
 // Minimal changes from Gemma TransformerBlock
 private class TransformerBlock: Module {
-
     @ModuleInfo(key: "self_attn") var attention: Attention
     let mlp: MLP
 
@@ -160,7 +155,6 @@ private class TransformerBlock: Module {
 
 // Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
 public class ModelInner: Module {
-
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
 
     fileprivate let layers: [TransformerBlock]
@@ -199,7 +193,6 @@ public class ModelInner: Module {
 
 // Uses Gemma2ModelInner, otherwise same as GemmaModel
 public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider {
-
     public let vocabularySize: Int
     public let kvHeads: [Int]
     public let headDim: IntOrPair
@@ -209,7 +202,7 @@ public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider {
 
     public init(_ args: Gemma2Configuration) {
         self.vocabularySize = args.vocabularySize
-        self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
+        self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
         self.headDim = .init(args.headDimensions)
         self.model = ModelInner(args)
         self.logitSoftCap = args.finalLogitSoftcapping
@@ -218,13 +211,12 @@ public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider {
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
         var out = model(inputs, cache: cache)
         out = model.embedTokens.asLinear(out)
-        out = tanh(out / self.logitSoftCap) * self.logitSoftCap
+        out = tanh(out / logitSoftCap) * logitSoftCap
         return out
     }
 }
 
 public struct Gemma2Configuration: Codable {
-
     var hiddenSize: Int
     var hiddenLayers: Int
     var intermediateSize: Int
@@ -237,7 +229,7 @@ public struct Gemma2Configuration: Codable {
     var ropeTraditional: Bool = false
     var attnLogitSoftcapping: Float = 50.0
     var finalLogitSoftcapping: Float = 30.0
-    var queryPreAttnScalar: Int = 256
+    var queryPreAttnScalar: Float = 144.0
 
     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -256,7 +248,7 @@ public struct Gemma2Configuration: Codable {
     }
 
     public init(from decoder: Decoder) throws {
-        // custom implementation to handle optional keys with required values
+        // Custom implementation to handle optional keys with required values
         let container: KeyedDecodingContainer<CodingKeys> = try decoder.container(
             keyedBy: CodingKeys.self)
 
@@ -286,7 +278,7 @@ public struct Gemma2Configuration: Codable {
         self.finalLogitSoftcapping = try container.decode(
             Float.self, forKey: CodingKeys.finalLogitSoftcapping)
         self.queryPreAttnScalar = try container.decode(
-            Int.self, forKey: CodingKeys.queryPreAttnScalar)
+            Float.self, forKey: CodingKeys.queryPreAttnScalar)
     }
 }
 
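A last note on the queryPreAttnScalar type change: config.json carries query_pre_attn_scalar as a plain JSON number (the defaults above show 256 and 144.0), and JSONDecoder decodes an integral JSON value into a Float property without complaint, so existing configs keep loading. A small hedged sketch (the Probe type exists only for illustration):

    import Foundation

    // An integral JSON number decodes cleanly into a Float property.
    struct Probe: Codable { let query_pre_attn_scalar: Float }

    let json = #"{ "query_pre_attn_scalar": 256 }"#.data(using: .utf8)!
    let probe = try! JSONDecoder().decode(Probe.self, from: json)
    print(probe.query_pre_attn_scalar)  // 256.0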