@@ -8,22 +8,6 @@ import Tokenizers
 
 // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
 
-// Specialized norm for gemma
-private class RMSNorm: Module, UnaryLayer {
-    let weight: MLXArray
-    let eps: Float
-
-    public init(dimensions: Int, eps: Float = 1e-5) {
-        self.weight = MLXArray.ones([dimensions])
-        self.eps = eps
-        super.init()
-    }
-
-    public func callAsFunction(_ x: MLXArray) -> MLXArray {
-        return MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps)
-    }
-}
-
 private class Attention: Module {
     let args: Gemma2Configuration
     let scale: Float
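For reference, the deleted class is presumably superseded by an equivalent shared layer, referenced below as Gemma.RMSNorm. A minimal sketch of that shared layer, assuming a Gemma enum used as a namespace in the common Gemma code (the namespace and file placement are assumptions; the body mirrors the class removed above):

import MLX
import MLXFast
import MLXNN

public enum Gemma {
    /// Gemma-style RMSNorm: scales the normalized input by (1 + weight) rather than weight.
    public class RMSNorm: Module, UnaryLayer {
        let weight: MLXArray
        let eps: Float

        public init(dimensions: Int, eps: Float = 1e-5) {
            self.weight = MLXArray.ones([dimensions])
            self.eps = eps
            super.init()
        }

        public func callAsFunction(_ x: MLXArray) -> MLXArray {
            MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps)
        }
    }
}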
@@ -125,21 +109,21 @@ private class TransformerBlock: Module {
     @ModuleInfo(key: "self_attn") var attention: Attention
     let mlp: MLP
 
-    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
-    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: Gemma.RMSNorm
+    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: Gemma.RMSNorm
+    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: Gemma.RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: Gemma.RMSNorm
 
     public init(_ args: Gemma2Configuration) {
         self._attention.wrappedValue = Attention(args)
         self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
-        self._inputLayerNorm.wrappedValue = RMSNorm(
+        self._inputLayerNorm.wrappedValue = Gemma.RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
+        self._preFeedforwardLayerNorm.wrappedValue = Gemma.RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
+        self._postFeedforwardLayerNorm.wrappedValue = Gemma.RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
+        self._postAttentionLayerNorm.wrappedValue = Gemma.RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
@@ -159,7 +143,7 @@ private class ModelInner: Module {
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
 
     fileprivate let layers: [TransformerBlock]
-    fileprivate let norm: RMSNorm
+    fileprivate let norm: Gemma.RMSNorm
 
     let hiddenScale: Float
 
@@ -175,7 +159,7 @@ private class ModelInner: Module {
         .map { _ in
             TransformerBlock(args)
         }
-        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self.norm = Gemma.RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
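As a quick sanity check (a hypothetical snippet, not part of the diff), the refactored call sites construct and apply the shared norm exactly as the old private class did:

let norm = Gemma.RMSNorm(dimensions: 4, eps: 1e-6)
let y = norm(MLXArray.ones([1, 4]))  // output keeps the input shape, [1, 4]
print(y.shape)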