Commit 45b2956

✨ feat: Add support for GLM4 model (#302)

1 parent f17047c
4 files changed (+289 −0 lines)

Libraries/MLXLLM/Documentation.docc/Documentation.md

Lines changed: 1 addition & 0 deletions

@@ -30,5 +30,6 @@ Example implementations of various Large Language Models (LLMs).
 - ``Qwen2Model``
 - ``Qwen3Model``
 - ``Starcoder2Model``
+- ``GLM4Model``

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 7 additions & 0 deletions

@@ -43,6 +43,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
         "openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
         "internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
         "granite": create(GraniteConfiguration.self, GraniteModel.init),
+        "glm4": create(GLM4Configuration.self, GLM4Model.init),
     ]
 }

@@ -199,6 +200,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: ""
     )

+    static public let glm4_9b_4bit = ModelConfiguration(
+        id: "mlx-community/GLM-4-9B-0414-4bit",
+        defaultPrompt: "Why is the sky blue?"
+    )
+
     private static func all() -> [ModelConfiguration] {
         [
             codeLlama13b4bit,

@@ -225,6 +231,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
             qwen3_4b_4bit,
             qwen3_8b_4bit,
             smolLM_135M_4bit,
+            glm4_9b_4bit,
         ]
     }
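With the "glm4" type and the glm4_9b_4bit configuration registered, the model loads like any other registry entry. Below is a minimal sketch in the style of llm-tool; it assumes the loadContainer / UserInput / generate APIs from MLXLMCommon as they existed around this commit, so exact signatures may differ between releases. The demoGLM4 function name and the 256-token cutoff are illustrative, not part of this commit.

import MLXLLM
import MLXLMCommon

// Sketch only: load the newly registered GLM4 model and generate text.
func demoGLM4() async throws {
    // Downloads mlx-community/GLM-4-9B-0414-4bit from the Hub on first use.
    let container = try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.glm4_9b_4bit)

    let result = try await container.perform { context in
        let input = try await context.processor.prepare(
            input: UserInput(prompt: "Why is the sky blue?"))
        return try MLXLMCommon.generate(
            input: input, parameters: GenerateParameters(), context: context
        ) { tokens in
            // Cap the demo at 256 generated tokens.
            tokens.count >= 256 ? .stop : .more
        }
    }
    print(result.output)
}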

Libraries/MLXLLM/Models/GLM4.swift

Lines changed: 280 additions & 0 deletions

@@ -0,0 +1,280 @@

//
// GLM4.swift
// LLM
//
// Created by John Mai on 2025/5/1.
//

import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN

// port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/glm4.py

private class Attention: Module {
    let args: GLM4Configuration
    let scale: Float

    @ModuleInfo(key: "q_proj") var wq: Linear
    @ModuleInfo(key: "k_proj") var wk: Linear
    @ModuleInfo(key: "v_proj") var wv: Linear
    @ModuleInfo(key: "o_proj") var wo: Linear

    let rope: RoPE

    public init(_ args: GLM4Configuration) {
        self.args = args

        let headDim = args.headDim > 0 ? args.headDim : args.hiddenSize / args.attentionHeads
        self.scale = pow(Float(headDim), -0.5)

        _wq.wrappedValue = Linear(
            args.hiddenSize, args.attentionHeads * headDim, bias: args.attentionBias)
        _wk.wrappedValue = Linear(args.hiddenSize, args.kvHeads * headDim, bias: args.attentionBias)
        _wv.wrappedValue = Linear(args.hiddenSize, args.kvHeads * headDim, bias: args.attentionBias)
        _wo.wrappedValue = Linear(args.attentionHeads * headDim, args.hiddenSize, bias: false)

        self.rope = RoPE(
            dimensions: Int(Float(headDim) * args.partialRotaryFactor),
            traditional: args.ropeTraditional, base: args.ropeTheta)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
    ) -> MLXArray {
        let (B, L) = (x.dim(0), x.dim(1))

        var queries = wq(x)
        var keys = wk(x)
        var values = wv(x)

        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

        if let cache {
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
            (keys, values) = cache.update(
                keys: keys, values: values)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale,
            mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return wo(output)
    }
}

private class MLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_up_proj") var gateUp: Linear
    @ModuleInfo(key: "down_proj") var down: Linear

    public init(_ args: GLM4Configuration) {
        _gateUp.wrappedValue = Linear(args.hiddenSize, 2 * args.intermediateSize, bias: false)
        _down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: false)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        let x = gateUp(x)
        let chunks = split(x, parts: 2, axis: -1)
        return down(silu(chunks[0]) * chunks[1])
    }
}

private class GLM4DecoderLayer: Module {
    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
    @ModuleInfo(key: "post_self_attn_layernorm") var postSelfAttnLayerNorm: RMSNorm
    @ModuleInfo(key: "post_mlp_layernorm") var postMlpLayerNorm: RMSNorm

    public init(_ args: GLM4Configuration) {
        _attention.wrappedValue = Attention(args)
        self.mlp = MLP(args)
        _inputLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        _postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        _postSelfAttnLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        _postMlpLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
    ) -> MLXArray {
        var x =
            x
            + postSelfAttnLayerNorm(
                attention(inputLayerNorm(x), mask: mask, cache: cache)
            )
        let residual = x
        x = postMlpLayerNorm(mlp(postAttentionLayerNorm(x))) + residual
        return x
    }
}

private class GLM4ModelInner: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    fileprivate let layers: [GLM4DecoderLayer]
    let norm: RMSNorm

    public init(_ args: GLM4Configuration) {
        precondition(args.vocabularySize > 0)

        _embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)

        self.layers = (0 ..< args.hiddenLayers)
            .map { _ in
                GLM4DecoderLayer(args)
            }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var h = embedTokens(inputs)

        let mask: MLXArray? = createAttentionMask(h: h, cache: cache)

        for (i, layer) in layers.enumerated() {
            h = layer(h, mask: mask, cache: cache?[i])
        }

        return norm(h)
    }
}

public class GLM4Model: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int
    public let kvHeads: [Int]

    private let model: GLM4ModelInner
    let configuration: GLM4Configuration
    let modelType: String

    @ModuleInfo(key: "lm_head") var lmHead: Linear

    public init(_ args: GLM4Configuration) {
        self.configuration = args
        self.vocabularySize = args.vocabularySize
        self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
        self.modelType = args.modelType
        self.model = GLM4ModelInner(args)

        _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        let out = model(inputs, cache: cache)
        return lmHead(out)
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        var weights = weights

        if configuration.tieWordEmbeddings {
            weights["lm_head.weight"] = nil
        }

        return weights
    }
}

public struct GLM4Configuration: Codable, Sendable {
    var hiddenSize: Int
    var hiddenLayers: Int
    var intermediateSize: Int
    var attentionHeads: Int
    var attentionBias: Bool
    var headDim: Int
    var rmsNormEps: Float
    var vocabularySize: Int
    var kvHeads: Int
    var partialRotaryFactor: Float
    var ropeTheta: Float = 10000.0
    var ropeTraditional: Bool = true
    var tieWordEmbeddings = false
    var maxPositionEmbeddings: Int = 32768
    var modelType: String

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
        case attentionHeads = "num_attention_heads"
        case attentionBias = "attention_bias"
        case headDim = "head_dim"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case kvHeads = "num_key_value_heads"
        case partialRotaryFactor = "partial_rotary_factor"
        case ropeTheta = "rope_theta"
        case ropeTraditional = "rope_traditional"
        case tieWordEmbeddings = "tie_word_embeddings"
        case maxPositionEmbeddings = "max_position_embeddings"
        case modelType = "model_type"
    }

    public init(from decoder: Decoder) throws {
        let container: KeyedDecodingContainer<GLM4Configuration.CodingKeys> =
            try decoder.container(
                keyedBy: GLM4Configuration.CodingKeys.self)

        self.modelType = try container.decode(
            String.self, forKey: GLM4Configuration.CodingKeys.modelType)
        self.hiddenSize = try container.decode(
            Int.self, forKey: GLM4Configuration.CodingKeys.hiddenSize)
        self.hiddenLayers = try container.decode(
            Int.self, forKey: GLM4Configuration.CodingKeys.hiddenLayers)
        self.intermediateSize = try container.decode(
            Int.self, forKey: GLM4Configuration.CodingKeys.intermediateSize)
        self.attentionHeads = try container.decode(
            Int.self, forKey: GLM4Configuration.CodingKeys.attentionHeads)
        self.attentionBias = try container.decode(
            Bool.self, forKey: GLM4Configuration.CodingKeys.attentionBias)
        self.headDim = try container.decode(
            Int.self, forKey: GLM4Configuration.CodingKeys.headDim)
        self.rmsNormEps = try container.decode(
            Float.self, forKey: GLM4Configuration.CodingKeys.rmsNormEps)
        self.vocabularySize = try container.decode(
            Int.self, forKey: GLM4Configuration.CodingKeys.vocabularySize)
        self.kvHeads = try container.decode(Int.self, forKey: GLM4Configuration.CodingKeys.kvHeads)
        self.partialRotaryFactor = try container.decode(
            Float.self, forKey: GLM4Configuration.CodingKeys.partialRotaryFactor)
        self.ropeTheta =
            try container.decodeIfPresent(
                Float.self, forKey: GLM4Configuration.CodingKeys.ropeTheta)
            ?? 10000.0
        self.ropeTraditional =
            try container.decodeIfPresent(
                Bool.self, forKey: GLM4Configuration.CodingKeys.ropeTraditional)
            ?? true
        self.tieWordEmbeddings =
            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
        self.maxPositionEmbeddings =
            try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768
    }
}

// MARK: - LoRA

extension GLM4Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}
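A note on the decoder layer: unlike a standard pre-norm block, GLM4DecoderLayer applies RMSNorm not only before each sublayer (input_layernorm, and post_attention_layernorm acting as the pre-MLP norm) but also to each sublayer's output (post_self_attn_layernorm, post_mlp_layernorm) before the residual add, matching the mlx-lm Python port referenced at the top of the file. Schematically:

h = x + RMSNorm(Attention(RMSNorm(x)))
out = h + RMSNorm(MLP(RMSNorm(h)))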
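The hand-written init(from:) decodes the keys that GLM4 checkpoints always carry and falls back to defaults for the optional ones (rope_theta, rope_traditional, tie_word_embeddings, max_position_embeddings) via decodeIfPresent. A small sketch of that behavior; the hyperparameter values below are made up for illustration and are not the real GLM-4-9B-0414 settings:

import Foundation

// Illustrative config.json fragment; optional keys deliberately omitted.
let json = """
    {
        "model_type": "glm4",
        "hidden_size": 1024,
        "num_hidden_layers": 2,
        "intermediate_size": 2816,
        "num_attention_heads": 16,
        "attention_bias": true,
        "head_dim": 64,
        "rms_norm_eps": 1e-5,
        "vocab_size": 32000,
        "num_key_value_heads": 4,
        "partial_rotary_factor": 0.5
    }
    """

let config = try JSONDecoder().decode(GLM4Configuration.self, from: Data(json.utf8))
print(config.ropeTheta)  // 10000.0: the decodeIfPresent fallback, since the key is absent
print(config.ropeTraditional)  // true: same fallback behavior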

Libraries/MLXLLM/README.md

Lines changed: 1 addition & 0 deletions

@@ -56,6 +56,7 @@ Currently supported model types are:
 - Qwen2
 - Qwen3
 - Starcoder2
+- GLM4

 See [llm-tool](../../Tools/llm-tool)
