From c835489902e19ceb12caf3b505e7c6fdfa48acd6 Mon Sep 17 00:00:00 2001 From: John Mai Date: Sat, 5 Jul 2025 00:06:26 +0800 Subject: [PATCH 1/2] feat: Add SmolLM3 --- Libraries/MLXLLM/LLMModelFactory.swift | 1 + Libraries/MLXLLM/Models/SmolLM3.swift | 370 +++++++++++++++++++++++++ 2 files changed, 371 insertions(+) create mode 100644 Libraries/MLXLLM/Models/SmolLM3.swift diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index f940a13b..820da8d7 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -49,6 +49,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable { "glm4": create(GLM4Configuration.self, GLM4Model.init), "acereason": create(Qwen2Configuration.self, Qwen2Model.init), "bitnet": create(BitnetConfiguration.self, BitnetModel.init), + "smollm3": create(SmolLM3Configuration.self, SmolLM3Model.init), ] } diff --git a/Libraries/MLXLLM/Models/SmolLM3.swift b/Libraries/MLXLLM/Models/SmolLM3.swift new file mode 100644 index 00000000..1e388e77 --- /dev/null +++ b/Libraries/MLXLLM/Models/SmolLM3.swift @@ -0,0 +1,370 @@ +// +// SmolLM3.swift +// mlx-swift-examples +// +// Created by John Mai on 2025/7/4. +// + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +private protocol PositionEmbedding { + func callAsFunction(_ x: MLXArray, offset: Int) -> MLXArray + func callAsFunction(_ x: MLXArray) -> MLXArray +} + +extension RoPE: PositionEmbedding {} + +// MARK: - NoPE + +private final class NoPE: Module, PositionEmbedding { + func callAsFunction(_ x: MLXArray, offset: Int) -> MLXArray { + return x + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + callAsFunction(x, offset: 0) + } +} + +// MARK: - Attention + +private class Attention: Module { + let args: SmolLM3Configuration + let scale: Float + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + var rope: PositionEmbedding + + init(_ args: SmolLM3Configuration) { + self.args = args + + let dim = args.hiddenSize + let heads = args.attentionHeads + let kvHeads = args.kvHeads + + let headDim = args.resolvedHeadDimensions + self.scale = pow(Float(headDim), -0.5) + + self.rope = RoPE( + dimensions: headDim, + traditional: args.ropeTraditional, + base: args.ropeTheta, + scale: 1.0 + ) + + _wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias) + _wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias) + _wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias) + _wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias) + } + + func callAsFunction( + _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache? 
+ ) -> MLXArray { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + + if let cache { + queries = rope(queries, offset: cache.offset) + keys = rope(keys, offset: cache.offset) + } else { + queries = rope(queries) + keys = rope(keys) + } + + let output = attentionWithCacheUpdate( + queries: queries, + keys: keys, + values: values, + cache: cache, + scale: scale, + mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return wo(output) + } +} + +// MARK: - MLP + +private class MLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "down_proj") var down: Linear + @ModuleInfo(key: "up_proj") var up: Linear + + init(_ args: SmolLM3Configuration) { + _gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias) + _down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias) + _up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias) + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let activation = silu(gate(x)) + return down(activation * up(x)) + } +} + +private class TransformerBlock: Module { + @ModuleInfo(key: "self_attn") var attention: Attention + @ModuleInfo(key: "mlp") var mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + init(_ args: SmolLM3Configuration) { + _attention.wrappedValue = Attention(args) + _mlp.wrappedValue = MLP(args) + _inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + _postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + func callAsFunction( + _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache? + ) -> MLXArray { + var r = attention(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + r + r = mlp(postAttentionLayerNorm(h)) + let out = h + r + return out + } +} + +// MARK: - Model + +private class SmolLM3ModelInner: Module { + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + fileprivate let layers: [TransformerBlock] + let norm: RMSNorm + + init(_ args: SmolLM3Configuration) { + precondition(args.vocabularySize > 0) + + _embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray { + var h = embedTokens(inputs) + + let mask = createAttentionMask(h: h, cache: cache) + + for (i, layer) in layers.enumerated() { + h = layer(h, mask: mask, cache: cache?[i]) + } + + return norm(h) + } +} + +public class SmolLM3Model: Module, LLMModel, KVCacheDimensionProvider { + public let vocabularySize: Int + public let kvHeads: [Int] + + private let model: SmolLM3ModelInner + let configuration: SmolLM3Configuration + + @ModuleInfo(key: "lm_head") var lmHead: Linear? 
+ + public init(_ args: SmolLM3Configuration) { + self.configuration = args + self.vocabularySize = args.vocabularySize + self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } + + self.model = SmolLM3ModelInner(args) + + if !args.tieWordEmbeddings { + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) + } + + super.init() + + let identityRope = NoPE() + for (idx, useRope) in args.noRopeLayers.enumerated() { + if useRope == 0 && idx < model.layers.count { + model.layers[idx].attention.rope = identityRope + } + } + } + + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { + var out = model(inputs, cache: cache) + if let lmHead { + out = lmHead(out) + } else { + out = model.embedTokens.asLinear(out) + } + return out + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + var weights = weights + + weights = weights.filter { key, _ in + !key.contains("self_attn.rotary_emb.inv_freq") + } + + if configuration.tieWordEmbeddings { + weights["lm_head.weight"] = nil + } + + return weights + } +} + +// MARK: - Configuration + +public struct SmolLM3Configuration: Codable, Sendable { + var modelType: String + var hiddenSize: Int + var hiddenLayers: Int + var intermediateSize: Int + var attentionHeads: Int + var headDimensions: Int? + var rmsNormEps: Float + var vocabularySize: Int + var kvHeads: Int + var maxPositionEmbeddings: Int? + var ropeTheta: Float = 10_000 + var ropeTraditional: Bool = false + var ropeScaling: [String: StringOrNumber]? + var tieWordEmbeddings: Bool = true + var attentionBias: Bool = false + var mlpBias: Bool = false + + var noRopeLayerInterval: Int = 4 + var noRopeLayers: [Int] = [] + + var resolvedHeadDimensions: Int { + headDimensions ?? (hiddenSize / attentionHeads) + } + + public init( + modelType: String = "smollm3", + hiddenSize: Int, + hiddenLayers: Int, + intermediateSize: Int, + attentionHeads: Int, + headDimensions: Int? = nil, + rmsNormEps: Float, + vocabularySize: Int, + kvHeads: Int, + maxPositionEmbeddings: Int? = nil, + ropeTheta: Float = 10_000, + ropeTraditional: Bool = false, + ropeScaling: [String: StringOrNumber]? = nil, + tieWordEmbeddings: Bool = true, + attentionBias: Bool = false, + mlpBias: Bool = false, + noRopeLayerInterval: Int = 4, + noRopeLayers: [Int]? = nil + ) { + self.modelType = modelType + self.hiddenSize = hiddenSize + self.hiddenLayers = hiddenLayers + self.intermediateSize = intermediateSize + self.attentionHeads = attentionHeads + self.headDimensions = headDimensions + self.rmsNormEps = rmsNormEps + self.vocabularySize = vocabularySize + self.kvHeads = kvHeads + self.maxPositionEmbeddings = maxPositionEmbeddings + self.ropeTheta = ropeTheta + self.ropeTraditional = ropeTraditional + self.ropeScaling = ropeScaling + self.tieWordEmbeddings = tieWordEmbeddings + self.attentionBias = attentionBias + self.mlpBias = mlpBias + self.noRopeLayerInterval = noRopeLayerInterval + + if let noRopeLayers = noRopeLayers { + self.noRopeLayers = noRopeLayers + } else { + self.noRopeLayers = (0.. 
<hiddenLayers).map { i in
+                (i + 1) % noRopeLayerInterval != 0 ? 1 : 0
+            }
+        }
+    }
+}
+
+extension SmolLM3Model: LoRAModel {
+    public func loraLinearLayers() ->
LoRALinearLayers { + model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } +} + From 4652c26c3757f040e35ca1afd3be1f9f6517ab66 Mon Sep 17 00:00:00 2001 From: John Mai Date: Fri, 11 Jul 2025 01:17:05 +0800 Subject: [PATCH 2/2] feat: Add SmolLM3 model configuration --- Libraries/MLXLLM/LLMModelFactory.swift | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index 820da8d7..a181bb0b 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -231,6 +231,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable { id: "mlx-community/bitnet-b1.58-2B-4T-4bit", defaultPrompt: "Why is the sky blue?" ) + + static public let smollm3_3b_4bit = ModelConfiguration( + id: "mlx-community/SmolLM3-3B-4bit", + defaultPrompt: "Why is the sky blue?" + ) private static func all() -> [ModelConfiguration] { [ @@ -264,6 +269,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable { glm4_9b_4bit, acereason_7b_4bit, bitnet_b1_58_2b_4t_4bit, + smollm3_3b_4bit, ] }
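Review note (not part of the patch): a minimal usage sketch of the two pieces added above. The first part constructs SmolLM3Model directly with toy placeholder hyperparameters (not SmolLM3-3B's real values) just to exercise the forward pass and the NoPE wiring; the second part assumes the repo's existing LLMModelFactory.shared.loadContainer(configuration:) entry point and ModelContainer type, which are not shown in this diff, to load the newly registered mlx-community/SmolLM3-3B-4bit checkpoint.

import MLX
import MLXLLM
import MLXLMCommon

// 1) Direct construction with toy placeholder hyperparameters (illustrative
//    only, not SmolLM3-3B's real configuration) to run a quick forward pass.
let toyConfig = SmolLM3Configuration(
    hiddenSize: 256,
    hiddenLayers: 8,
    intermediateSize: 512,
    attentionHeads: 8,
    rmsNormEps: 1e-6,
    vocabularySize: 1_000,
    kvHeads: 2
)
let toyModel = SmolLM3Model(toyConfig)

// Layers whose noRopeLayers entry is 0 are rewired to the identity NoPE
// position embedding in SmolLM3Model.init; the rest keep standard RoPE.
let tokens = MLXArray([1, 2, 3, 4]).reshaped(1, 4)
let logits = toyModel(tokens, cache: nil)
print(logits.shape)  // [1, 4, 1000]

// 2) Loading the registered 4-bit checkpoint. This assumes the factory API
//    already used by the other LLMRegistry entries in this repo.
func loadSmolLM3() async throws -> ModelContainer {
    try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.smollm3_3b_4bit)
}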