
Commit 2cceb87

Add Gemma 3 (#238)
* Working on KV cache
* Working on Gemma 3
* Add model configurations for all Gemma 3 sizes
1 parent: fb3184c


9 files changed (+1,718 / −55 lines)


Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 9 additions & 0 deletions
@@ -42,6 +42,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
         "cohere": create(CohereConfiguration.self, CohereModel.init),
         "openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
         "internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
+        "gemma3_text": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
         "granite": create(GraniteConfiguration.self, GraniteModel.init),
         "mimo": create(MiMoConfiguration.self, MiMoModel.init),
         "glm4": create(GLM4Configuration.self, GLM4Model.init),
@@ -197,6 +198,12 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
     )
 
+    static public let gemma3_1B_qat_4bit = ModelConfiguration(
+        id: "mlx-community/gemma-3-1b-it-qat-4bit",
+        defaultPrompt: "What is the difference between a fruit and a vegetable?",
+        extraEOSTokens: ["<end_of_turn>"]
+    )
+
     static public let granite3_3_2b_4bit = ModelConfiguration(
         id: "mlx-community/granite-3.3-2b-instruct-4bit",
         defaultPrompt: ""
@@ -244,6 +251,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         qwen3_8b_4bit,
         qwen3MoE_30b_a3b_4bit,
         smolLM_135M_4bit,
+        gemma3_1B_qat_4bit,
         mimo_7b_sft_4bit,
         glm4_9b_4bit,
         acereason_7b_4bit,
@@ -275,6 +283,7 @@ private struct LLMUserInputProcessor: UserInputProcessor {
         do {
             let promptTokens = try tokenizer.applyChatTemplate(
                 messages: messages, tools: input.tools, additionalContext: input.additionalContext)
+
             return LMInput(tokens: MLXArray(promptTokens))
         } catch TokenizerError.missingChatTemplate {
             print(
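The new "gemma3_text" entry lets the factory decode checkpoints whose config declares that model type into Gemma3TextModel, and gemma3_1B_qat_4bit points the registry at the quantized 1B instruct checkpoint on the Hub. A minimal usage sketch follows; it assumes the library's LLMModelFactory.shared singleton and its async loadContainer(configuration:) API, which are not shown in this diff, so treat those names as assumptions rather than part of this commit:

import MLXLLM
import MLXLMCommon

// Loading sketch: only LLMRegistry.gemma3_1B_qat_4bit is introduced by this
// commit; LLMModelFactory.shared and loadContainer(configuration:) are assumed
// from the surrounding library.
func loadGemma3() async throws -> ModelContainer {
    // The registry entry carries the Hub id, a default prompt, and the extra
    // "<end_of_turn>" EOS token that Gemma uses to terminate a turn.
    try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.gemma3_1B_qat_4bit)
}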

Libraries/MLXLLM/Models/Gemma.swift

Lines changed: 6 additions & 6 deletions
@@ -108,15 +108,15 @@ private class TransformerBlock: Module {
     @ModuleInfo(key: "self_attn") var attention: Attention
     let mlp: MLP
 
-    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: Gemma.RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: Gemma.RMSNorm
 
     public init(_ args: GemmaConfiguration) {
         self._attention.wrappedValue = Attention(args)
         self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
-        self._inputLayerNorm.wrappedValue = RMSNorm(
+        self._inputLayerNorm.wrappedValue = Gemma.RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
+        self._postAttentionLayerNorm.wrappedValue = Gemma.RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
@@ -137,7 +137,7 @@ private class GemmaModelInner: Module {
 
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
     fileprivate let layers: [TransformerBlock]
-    fileprivate let norm: RMSNorm
+    fileprivate let norm: Gemma.RMSNorm
 
     public init(_ args: GemmaConfiguration) {
         precondition(args.vocabularySize > 0)
@@ -153,7 +153,7 @@ private class GemmaModelInner: Module {
             .map { _ in
                 TransformerBlock(args)
             }
-        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self.norm = Gemma.RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
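Both Gemma.swift and Gemma2.swift now reference a shared Gemma.RMSNorm rather than each keeping a private copy. Its definition is not part of this excerpt; the sketch below reconstructs what it presumably looks like from the private RMSNorm that Gemma2.swift deletes further down, with the Gemma enum namespace being an assumption about where the shared type lives:

import MLX
import MLXFast
import MLXNN

// Sketch of the shared norm, mirroring the private RMSNorm removed from
// Gemma2.swift below. Placement inside a `Gemma` namespace is assumed.
public enum Gemma {
    /// Gemma-style RMSNorm: the learned weight is stored as a residual and
    /// applied as (1 + weight), matching the reference implementation.
    public class RMSNorm: Module, UnaryLayer {
        let weight: MLXArray
        let eps: Float

        public init(dimensions: Int, eps: Float = 1e-5) {
            self.weight = MLXArray.ones([dimensions])
            self.eps = eps
            super.init()
        }

        public func callAsFunction(_ x: MLXArray) -> MLXArray {
            MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps)
        }
    }
}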

Libraries/MLXLLM/Models/Gemma2.swift

Lines changed: 10 additions & 26 deletions
@@ -8,22 +8,6 @@ import Tokenizers
 
 // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
 
-// Specialized norm for gemma
-private class RMSNorm: Module, UnaryLayer {
-    let weight: MLXArray
-    let eps: Float
-
-    public init(dimensions: Int, eps: Float = 1e-5) {
-        self.weight = MLXArray.ones([dimensions])
-        self.eps = eps
-        super.init()
-    }
-
-    public func callAsFunction(_ x: MLXArray) -> MLXArray {
-        return MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps)
-    }
-}
-
 private class Attention: Module {
     let args: Gemma2Configuration
     let scale: Float
@@ -125,21 +109,21 @@ private class TransformerBlock: Module {
     @ModuleInfo(key: "self_attn") var attention: Attention
     let mlp: MLP
 
-    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
-    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: Gemma.RMSNorm
+    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: Gemma.RMSNorm
+    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: Gemma.RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: Gemma.RMSNorm
 
     public init(_ args: Gemma2Configuration) {
         self._attention.wrappedValue = Attention(args)
         self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
-        self._inputLayerNorm.wrappedValue = RMSNorm(
+        self._inputLayerNorm.wrappedValue = Gemma.RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
+        self._preFeedforwardLayerNorm.wrappedValue = Gemma.RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
+        self._postFeedforwardLayerNorm.wrappedValue = Gemma.RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
+        self._postAttentionLayerNorm.wrappedValue = Gemma.RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
@@ -159,7 +143,7 @@ private class ModelInner: Module {
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
 
     fileprivate let layers: [TransformerBlock]
-    fileprivate let norm: RMSNorm
+    fileprivate let norm: Gemma.RMSNorm
 
     let hiddenScale: Float
 
@@ -175,7 +159,7 @@ private class ModelInner: Module {
             .map { _ in
                 TransformerBlock(args)
             }
-        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self.norm = Gemma.RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
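The commit message mentions model configurations for all Gemma 3 sizes; other sizes follow the same ModelConfiguration pattern used for gemma3_1B_qat_4bit in LLMModelFactory.swift above. For example, a sketch for a larger variant (the Hub id below is illustrative and not taken from this diff):

let gemma3_4B_4bit = ModelConfiguration(
    // Hypothetical Hub id, shown only to illustrate the pattern.
    id: "mlx-community/gemma-3-4b-it-qat-4bit",
    defaultPrompt: "What is the difference between a fruit and a vegetable?",
    // Gemma chat templates close each turn with <end_of_turn>, so it is
    // registered as an extra EOS token to stop generation cleanly.
    extraEOSTokens: ["<end_of_turn>"]
)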
