|
| 1 | +// |
| 2 | +// GLM4.swift |
| 3 | +// LLM |
| 4 | +// |
| 5 | +// Created by John Mai on 2025/5/1. |
| 6 | +// |
| 7 | + |
| 8 | +import Foundation |
| 9 | +import MLX |
| 10 | +import MLXFast |
| 11 | +import MLXLMCommon |
| 12 | +import MLXNN |
| 13 | + |
| 14 | +// port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/glm4.py |
| 15 | + |
| 16 | +private class Attention: Module { |
| 17 | + let args: GLM4Configuration |
| 18 | + let scale: Float |
| 19 | + |
| 20 | + @ModuleInfo(key: "q_proj") var wq: Linear |
| 21 | + @ModuleInfo(key: "k_proj") var wk: Linear |
| 22 | + @ModuleInfo(key: "v_proj") var wv: Linear |
| 23 | + @ModuleInfo(key: "o_proj") var wo: Linear |
| 24 | + |
| 25 | + let rope: RoPE |
| 26 | + |
| 27 | + public init(_ args: GLM4Configuration) { |
| 28 | + self.args = args |
| 29 | + |
| 30 | + let headDim = args.headDim > 0 ? args.headDim : args.hiddenSize / args.attentionHeads |
| 31 | + self.scale = pow(Float(headDim), -0.5) |
| 32 | + |
| 33 | + _wq.wrappedValue = Linear( |
| 34 | + args.hiddenSize, args.attentionHeads * headDim, bias: args.attentionBias) |
| 35 | + _wk.wrappedValue = Linear(args.hiddenSize, args.kvHeads * headDim, bias: args.attentionBias) |
| 36 | + _wv.wrappedValue = Linear(args.hiddenSize, args.kvHeads * headDim, bias: args.attentionBias) |
| 37 | + _wo.wrappedValue = Linear(args.attentionHeads * headDim, args.hiddenSize, bias: false) |
| 38 | + |
| 39 | + self.rope = RoPE( |
| 40 | + dimensions: Int(Float(headDim) * args.partialRotaryFactor), |
| 41 | + traditional: args.ropeTraditional, base: args.ropeTheta) |
| 42 | + } |
| 43 | + |
| 44 | + public func callAsFunction( |
| 45 | + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? |
| 46 | + ) -> MLXArray { |
| 47 | + let (B, L) = (x.dim(0), x.dim(1)) |
| 48 | + |
| 49 | + var queries = wq(x) |
| 50 | + var keys = wk(x) |
| 51 | + var values = wv(x) |
| 52 | + |
| 53 | + queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3) |
| 54 | + keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) |
| 55 | + values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) |
| 56 | + |
| 57 | + if let cache { |
| 58 | + queries = rope(queries, offset: cache.offset) |
| 59 | + keys = rope(keys, offset: cache.offset) |
| 60 | + (keys, values) = cache.update( |
| 61 | + keys: keys, values: values) |
| 62 | + } else { |
| 63 | + queries = rope(queries) |
| 64 | + keys = rope(keys) |
| 65 | + } |
| 66 | + |
| 67 | + let output = MLXFast.scaledDotProductAttention( |
| 68 | + queries: queries, keys: keys, values: values, scale: scale, |
| 69 | + mask: mask |
| 70 | + ) |
| 71 | + .transposed(0, 2, 1, 3) |
| 72 | + .reshaped(B, L, -1) |
| 73 | + |
| 74 | + return wo(output) |
| 75 | + } |
| 76 | +} |
| 77 | + |
| 78 | +private class MLP: Module, UnaryLayer { |
| 79 | + @ModuleInfo(key: "gate_up_proj") var gateUp: Linear |
| 80 | + @ModuleInfo(key: "down_proj") var down: Linear |
| 81 | + |
| 82 | + public init(_ args: GLM4Configuration) { |
| 83 | + _gateUp.wrappedValue = Linear(args.hiddenSize, 2 * args.intermediateSize, bias: false) |
| 84 | + _down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: false) |
| 85 | + } |
| 86 | + |
| 87 | + public func callAsFunction(_ x: MLXArray) -> MLXArray { |
| 88 | + let x = gateUp(x) |
| 89 | + let chunks = split(x, parts: 2, axis: -1) |
| 90 | + return down(silu(chunks[0]) * chunks[1]) |
| 91 | + } |
| 92 | +} |
| 93 | + |
| 94 | +private class GLM4DecoderLayer: Module { |
| 95 | + @ModuleInfo(key: "self_attn") var attention: Attention |
| 96 | + let mlp: MLP |
| 97 | + |
| 98 | + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm |
| 99 | + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm |
| 100 | + @ModuleInfo(key: "post_self_attn_layernorm") var postSelfAttnLayerNorm: RMSNorm |
| 101 | + @ModuleInfo(key: "post_mlp_layernorm") var postMlpLayerNorm: RMSNorm |
| 102 | + |
| 103 | + public init(_ args: GLM4Configuration) { |
| 104 | + _attention.wrappedValue = Attention(args) |
| 105 | + self.mlp = MLP(args) |
| 106 | + _inputLayerNorm.wrappedValue = RMSNorm( |
| 107 | + dimensions: args.hiddenSize, eps: args.rmsNormEps) |
| 108 | + _postAttentionLayerNorm.wrappedValue = RMSNorm( |
| 109 | + dimensions: args.hiddenSize, eps: args.rmsNormEps) |
| 110 | + _postSelfAttnLayerNorm.wrappedValue = RMSNorm( |
| 111 | + dimensions: args.hiddenSize, eps: args.rmsNormEps) |
| 112 | + _postMlpLayerNorm.wrappedValue = RMSNorm( |
| 113 | + dimensions: args.hiddenSize, eps: args.rmsNormEps) |
| 114 | + } |
| 115 | + |
| 116 | + public func callAsFunction( |
| 117 | + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? |
| 118 | + ) -> MLXArray { |
| 119 | + var x = |
| 120 | + x |
| 121 | + + postSelfAttnLayerNorm( |
| 122 | + attention(inputLayerNorm(x), mask: mask, cache: cache) |
| 123 | + ) |
| 124 | + let residual = x |
| 125 | + x = postMlpLayerNorm(mlp(postAttentionLayerNorm(x))) + residual |
| 126 | + return x |
| 127 | + } |
| 128 | +} |
| 129 | + |
| 130 | +private class GLM4ModelInner: Module { |
| 131 | + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding |
| 132 | + |
| 133 | + fileprivate let layers: [GLM4DecoderLayer] |
| 134 | + let norm: RMSNorm |
| 135 | + |
| 136 | + public init(_ args: GLM4Configuration) { |
| 137 | + precondition(args.vocabularySize > 0) |
| 138 | + |
| 139 | + _embedTokens.wrappedValue = Embedding( |
| 140 | + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) |
| 141 | + |
| 142 | + self.layers = (0 ..< args.hiddenLayers) |
| 143 | + .map { _ in |
| 144 | + GLM4DecoderLayer(args) |
| 145 | + } |
| 146 | + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) |
| 147 | + } |
| 148 | + |
| 149 | + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray { |
| 150 | + var h = embedTokens(inputs) |
| 151 | + |
| 152 | + let mask: MLXArray? = createAttentionMask(h: h, cache: cache) |
| 153 | + |
| 154 | + for (i, layer) in layers.enumerated() { |
| 155 | + h = layer(h, mask: mask, cache: cache?[i]) |
| 156 | + } |
| 157 | + |
| 158 | + return norm(h) |
| 159 | + } |
| 160 | +} |
| 161 | + |
| 162 | +public class GLM4Model: Module, LLMModel, KVCacheDimensionProvider { |
| 163 | + public let vocabularySize: Int |
| 164 | + public let kvHeads: [Int] |
| 165 | + |
| 166 | + private let model: GLM4ModelInner |
| 167 | + let configuration: GLM4Configuration |
| 168 | + let modelType: String |
| 169 | + |
| 170 | + @ModuleInfo(key: "lm_head") var lmHead: Linear |
| 171 | + |
| 172 | + public init(_ args: GLM4Configuration) { |
| 173 | + self.configuration = args |
| 174 | + self.vocabularySize = args.vocabularySize |
| 175 | + self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } |
| 176 | + self.modelType = args.modelType |
| 177 | + self.model = GLM4ModelInner(args) |
| 178 | + |
| 179 | + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) |
| 180 | + } |
| 181 | + |
| 182 | + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { |
| 183 | + let out = model(inputs, cache: cache) |
| 184 | + return lmHead(out) |
| 185 | + } |
| 186 | + |
| 187 | + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { |
| 188 | + var weights = weights |
| 189 | + |
| 190 | + if configuration.tieWordEmbeddings { |
| 191 | + weights["lm_head.weight"] = nil |
| 192 | + } |
| 193 | + |
| 194 | + return weights |
| 195 | + } |
| 196 | +} |
| 197 | + |
| 198 | +public struct GLM4Configuration: Codable, Sendable { |
| 199 | + var hiddenSize: Int |
| 200 | + var hiddenLayers: Int |
| 201 | + var intermediateSize: Int |
| 202 | + var attentionHeads: Int |
| 203 | + var attentionBias: Bool |
| 204 | + var headDim: Int |
| 205 | + var rmsNormEps: Float |
| 206 | + var vocabularySize: Int |
| 207 | + var kvHeads: Int |
| 208 | + var partialRotaryFactor: Float |
| 209 | + var ropeTheta: Float = 10000.0 |
| 210 | + var ropeTraditional: Bool = true |
| 211 | + var tieWordEmbeddings = false |
| 212 | + var maxPositionEmbeddings: Int = 32768 |
| 213 | + var modelType: String |
| 214 | + |
| 215 | + enum CodingKeys: String, CodingKey { |
| 216 | + case hiddenSize = "hidden_size" |
| 217 | + case hiddenLayers = "num_hidden_layers" |
| 218 | + case intermediateSize = "intermediate_size" |
| 219 | + case attentionHeads = "num_attention_heads" |
| 220 | + case attentionBias = "attention_bias" |
| 221 | + case headDim = "head_dim" |
| 222 | + case rmsNormEps = "rms_norm_eps" |
| 223 | + case vocabularySize = "vocab_size" |
| 224 | + case kvHeads = "num_key_value_heads" |
| 225 | + case partialRotaryFactor = "partial_rotary_factor" |
| 226 | + case ropeTheta = "rope_theta" |
| 227 | + case ropeTraditional = "rope_traditional" |
| 228 | + case tieWordEmbeddings = "tie_word_embeddings" |
| 229 | + case maxPositionEmbeddings = "max_position_embeddings" |
| 230 | + case modelType = "model_type" |
| 231 | + } |
| 232 | + |
| 233 | + public init(from decoder: Decoder) throws { |
| 234 | + let container: KeyedDecodingContainer<GLM4Configuration.CodingKeys> = |
| 235 | + try decoder.container( |
| 236 | + keyedBy: GLM4Configuration.CodingKeys.self) |
| 237 | + |
| 238 | + self.modelType = try container.decode( |
| 239 | + String.self, forKey: GLM4Configuration.CodingKeys.modelType) |
| 240 | + self.hiddenSize = try container.decode( |
| 241 | + Int.self, forKey: GLM4Configuration.CodingKeys.hiddenSize) |
| 242 | + self.hiddenLayers = try container.decode( |
| 243 | + Int.self, forKey: GLM4Configuration.CodingKeys.hiddenLayers) |
| 244 | + self.intermediateSize = try container.decode( |
| 245 | + Int.self, forKey: GLM4Configuration.CodingKeys.intermediateSize) |
| 246 | + self.attentionHeads = try container.decode( |
| 247 | + Int.self, forKey: GLM4Configuration.CodingKeys.attentionHeads) |
| 248 | + self.attentionBias = try container.decode( |
| 249 | + Bool.self, forKey: GLM4Configuration.CodingKeys.attentionBias) |
| 250 | + self.headDim = try container.decode( |
| 251 | + Int.self, forKey: GLM4Configuration.CodingKeys.headDim) |
| 252 | + self.rmsNormEps = try container.decode( |
| 253 | + Float.self, forKey: GLM4Configuration.CodingKeys.rmsNormEps) |
| 254 | + self.vocabularySize = try container.decode( |
| 255 | + Int.self, forKey: GLM4Configuration.CodingKeys.vocabularySize) |
| 256 | + self.kvHeads = try container.decode(Int.self, forKey: GLM4Configuration.CodingKeys.kvHeads) |
| 257 | + self.partialRotaryFactor = try container.decode( |
| 258 | + Float.self, forKey: GLM4Configuration.CodingKeys.partialRotaryFactor) |
| 259 | + self.ropeTheta = |
| 260 | + try container.decodeIfPresent( |
| 261 | + Float.self, forKey: GLM4Configuration.CodingKeys.ropeTheta) |
| 262 | + ?? 10000.0 |
| 263 | + self.ropeTraditional = |
| 264 | + try container.decodeIfPresent( |
| 265 | + Bool.self, forKey: GLM4Configuration.CodingKeys.ropeTraditional) |
| 266 | + ?? true |
| 267 | + self.tieWordEmbeddings = |
| 268 | + try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false |
| 269 | + self.maxPositionEmbeddings = |
| 270 | + try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768 |
| 271 | + } |
| 272 | +} |
| 273 | + |
| 274 | +// MARK: - LoRA |
| 275 | + |
| 276 | +extension GLM4Model: LoRAModel { |
| 277 | + public func loraLinearLayers() -> LoRALinearLayers { |
| 278 | + model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } |
| 279 | + } |
| 280 | +} |
0 commit comments