Add SmolLM3 #349

Open · wants to merge 1 commit into main

1 change: 1 addition & 0 deletions Libraries/MLXLLM/LLMModelFactory.swift
@@ -49,6 +49,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
"glm4": create(GLM4Configuration.self, GLM4Model.init),
"acereason": create(Qwen2Configuration.self, Qwen2Model.init),
"bitnet": create(BitnetConfiguration.self, BitnetModel.init),
"smollm3": create(SmolLM3Configuration.self, SmolLM3Model.init),
]
}
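
Registering the "smollm3" key means checkpoints whose config.json declares "model_type": "smollm3" are decoded with SmolLM3Configuration and instantiated as SmolLM3Model, both added in the new file below.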

370 changes: 370 additions & 0 deletions Libraries/MLXLLM/Models/SmolLM3.swift
@@ -0,0 +1,370 @@
//
// SmolLM3.swift
// mlx-swift-examples
//
// Created by John Mai on 2025/7/4.
//

import Foundation
import MLX
import MLXLMCommon
import MLXNN

private protocol PositionEmbedding {
func callAsFunction(_ x: MLXArray, offset: Int) -> MLXArray
func callAsFunction(_ x: MLXArray) -> MLXArray
}

extension RoPE: PositionEmbedding {}

// MARK: - NoPE
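// SmolLM3 interleaves rotary-embedded attention layers with "NoPE" layers that
// apply no positional rotation; NoPE simply passes its input through unchanged.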

private final class NoPE: Module, PositionEmbedding {
func callAsFunction(_ x: MLXArray, offset: Int) -> MLXArray {
return x
}

func callAsFunction(_ x: MLXArray) -> MLXArray {
callAsFunction(x, offset: 0)
}
}

// MARK: - Attention

private class Attention: Module {
let args: SmolLM3Configuration
let scale: Float

@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear

var rope: PositionEmbedding

init(_ args: SmolLM3Configuration) {
self.args = args

let dim = args.hiddenSize
let heads = args.attentionHeads
let kvHeads = args.kvHeads

let headDim = args.resolvedHeadDimensions
self.scale = pow(Float(headDim), -0.5)

// Every layer starts out with RoPE; SmolLM3Model.init swaps in NoPE for the
// layers whose noRopeLayers entry is 0.
self.rope = RoPE(
dimensions: headDim,
traditional: args.ropeTraditional,
base: args.ropeTheta,
scale: 1.0
)

_wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias)
_wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
_wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
_wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias)
}

func callAsFunction(
_ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
) -> MLXArray {
let (B, L) = (x.dim(0), x.dim(1))

var queries = wq(x)
var keys = wk(x)
var values = wv(x)

queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

if let cache {
queries = rope(queries, offset: cache.offset)
keys = rope(keys, offset: cache.offset)
} else {
queries = rope(queries)
keys = rope(keys)
}

let output = attentionWithCacheUpdate(
queries: queries,
keys: keys,
values: values,
cache: cache,
scale: scale,
mask: mask
)
.transposed(0, 2, 1, 3)
.reshaped(B, L, -1)

return wo(output)
}
}

// MARK: - MLP

private class MLP: Module, UnaryLayer {
@ModuleInfo(key: "gate_proj") var gate: Linear
@ModuleInfo(key: "down_proj") var down: Linear
@ModuleInfo(key: "up_proj") var up: Linear

init(_ args: SmolLM3Configuration) {
_gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
_down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias)
_up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
}

func callAsFunction(_ x: MLXArray) -> MLXArray {
let activation = silu(gate(x))
return down(activation * up(x))
}
}

private class TransformerBlock: Module {
@ModuleInfo(key: "self_attn") var attention: Attention
@ModuleInfo(key: "mlp") var mlp: MLP

@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

init(_ args: SmolLM3Configuration) {
_attention.wrappedValue = Attention(args)
_mlp.wrappedValue = MLP(args)
_inputLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
_postAttentionLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
}

func callAsFunction(
_ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
) -> MLXArray {
var r = attention(inputLayerNorm(x), mask: mask, cache: cache)
let h = x + r
r = mlp(postAttentionLayerNorm(h))
let out = h + r
return out
}
}

// MARK: - Model

private class SmolLM3ModelInner: Module {
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

fileprivate let layers: [TransformerBlock]
let norm: RMSNorm

init(_ args: SmolLM3Configuration) {
precondition(args.vocabularySize > 0)

_embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)

self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) }
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
}

func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
var h = embedTokens(inputs)

let mask = createAttentionMask(h: h, cache: cache)

for (i, layer) in layers.enumerated() {
h = layer(h, mask: mask, cache: cache?[i])
}

return norm(h)
}
}

public class SmolLM3Model: Module, LLMModel, KVCacheDimensionProvider {
public let vocabularySize: Int
public let kvHeads: [Int]

private let model: SmolLM3ModelInner
let configuration: SmolLM3Configuration

@ModuleInfo(key: "lm_head") var lmHead: Linear?

public init(_ args: SmolLM3Configuration) {
self.configuration = args
self.vocabularySize = args.vocabularySize
self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }

self.model = SmolLM3ModelInner(args)

if !args.tieWordEmbeddings {
_lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
}

super.init()
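
// In noRopeLayers a 1 means the layer keeps RoPE and a 0 marks a NoPE layer,
// which has its rotary embedding replaced with the identity implementation.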

let identityRope = NoPE()
for (idx, useRope) in args.noRopeLayers.enumerated() {
if useRope == 0 && idx < model.layers.count {
model.layers[idx].attention.rope = identityRope
}
}
}

public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
var out = model(inputs, cache: cache)
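// With tied word embeddings there is no separate lm_head; project the hidden
// states through the embedding matrix instead.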
if let lmHead {
out = lmHead(out)
} else {
out = model.embedTokens.asLinear(out)
}
return out
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var weights = weights

// Precomputed rotary inv_freq buffers from upstream checkpoints are not
// parameters of this implementation, so drop them.
weights = weights.filter { key, _ in
!key.contains("self_attn.rotary_emb.inv_freq")
}

// With tied embeddings the output projection reuses embed_tokens, so any
// stored lm_head.weight is redundant.
if configuration.tieWordEmbeddings {
weights["lm_head.weight"] = nil
}

return weights
}
}

// MARK: - Configuration

public struct SmolLM3Configuration: Codable, Sendable {
var modelType: String
var hiddenSize: Int
var hiddenLayers: Int
var intermediateSize: Int
var attentionHeads: Int
var headDimensions: Int?
var rmsNormEps: Float
var vocabularySize: Int
var kvHeads: Int
var maxPositionEmbeddings: Int?
var ropeTheta: Float = 10_000
var ropeTraditional: Bool = false
var ropeScaling: [String: StringOrNumber]?
var tieWordEmbeddings: Bool = true
var attentionBias: Bool = false
var mlpBias: Bool = false

// When no_rope_layers is absent from config.json, every noRopeLayerInterval-th
// layer (counting from 1) is generated as a NoPE layer (0); all others keep RoPE (1).
var noRopeLayerInterval: Int = 4
var noRopeLayers: [Int] = []

var resolvedHeadDimensions: Int {
headDimensions ?? (hiddenSize / attentionHeads)
}

public init(
modelType: String = "smollm3",
hiddenSize: Int,
hiddenLayers: Int,
intermediateSize: Int,
attentionHeads: Int,
headDimensions: Int? = nil,
rmsNormEps: Float,
vocabularySize: Int,
kvHeads: Int,
maxPositionEmbeddings: Int? = nil,
ropeTheta: Float = 10_000,
ropeTraditional: Bool = false,
ropeScaling: [String: StringOrNumber]? = nil,
tieWordEmbeddings: Bool = true,
attentionBias: Bool = false,
mlpBias: Bool = false,
noRopeLayerInterval: Int = 4,
noRopeLayers: [Int]? = nil
) {
self.modelType = modelType
self.hiddenSize = hiddenSize
self.hiddenLayers = hiddenLayers
self.intermediateSize = intermediateSize
self.attentionHeads = attentionHeads
self.headDimensions = headDimensions
self.rmsNormEps = rmsNormEps
self.vocabularySize = vocabularySize
self.kvHeads = kvHeads
self.maxPositionEmbeddings = maxPositionEmbeddings
self.ropeTheta = ropeTheta
self.ropeTraditional = ropeTraditional
self.ropeScaling = ropeScaling
self.tieWordEmbeddings = tieWordEmbeddings
self.attentionBias = attentionBias
self.mlpBias = mlpBias
self.noRopeLayerInterval = noRopeLayerInterval

if let noRopeLayers = noRopeLayers {
self.noRopeLayers = noRopeLayers
} else {
self.noRopeLayers = (0..<hiddenLayers).map { i in
(i + 1) % noRopeLayerInterval != 0 ? 1 : 0
}
}
}

enum CodingKeys: String, CodingKey {
case modelType = "model_type"
case hiddenSize = "hidden_size"
case hiddenLayers = "num_hidden_layers"
case intermediateSize = "intermediate_size"
case attentionHeads = "num_attention_heads"
case headDimensions = "head_dim"
case rmsNormEps = "rms_norm_eps"
case vocabularySize = "vocab_size"
case kvHeads = "num_key_value_heads"
case maxPositionEmbeddings = "max_position_embeddings"
case ropeTheta = "rope_theta"
case ropeTraditional = "rope_traditional"
case ropeScaling = "rope_scaling"
case tieWordEmbeddings = "tie_word_embeddings"
case attentionBias = "attention_bias"
case mlpBias = "mlp_bias"
case noRopeLayerInterval = "no_rope_layer_interval"
case noRopeLayers = "no_rope_layers"
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)

self.modelType = try container.decodeIfPresent(String.self, forKey: .modelType) ?? "smollm3"
self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
self.headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
self.kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads
self.maxPositionEmbeddings = try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings)
self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 10_000
self.ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false
self.ropeScaling = try container.decodeIfPresent([String: StringOrNumber].self, forKey: .ropeScaling)
self.tieWordEmbeddings = try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? true
self.attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) ?? false
self.mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) ?? false

self.noRopeLayerInterval = try container.decodeIfPresent(Int.self, forKey: .noRopeLayerInterval) ?? 4

if let noRopeLayers = try container.decodeIfPresent([Int].self, forKey: .noRopeLayers) {
self.noRopeLayers = noRopeLayers
} else {
self.noRopeLayers = (0..<hiddenLayers).map { i in
(i + 1) % noRopeLayerInterval != 0 ? 1 : 0
}
}
}
}

// MARK: - LoRA

extension SmolLM3Model: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
}
}
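
For reference, a minimal usage sketch under stated assumptions: the JSON below is a hypothetical, trimmed config.json (real SmolLM3 checkpoints carry more keys and different values), and weight loading is omitted. In practice the factory in LLMModelFactory.swift (patched above) resolves the registered "smollm3" model_type and drives checkpoint loading; this only shows that the Codable path and the model initializer fit together.

import Foundation
import MLXLLM

// Hypothetical, trimmed config.json; values are placeholders, not the released ones.
let json = """
{
"model_type": "smollm3",
"hidden_size": 2048,
"num_hidden_layers": 8,
"intermediate_size": 5504,
"num_attention_heads": 16,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-6,
"vocab_size": 128256
}
""".data(using: .utf8)!

do {
// init(from:) fills in defaults such as ropeTheta and the noRopeLayers mask
// (a NoPE layer every 4th layer when no_rope_layers is missing).
let config = try JSONDecoder().decode(SmolLM3Configuration.self, from: json)
let model = SmolLM3Model(config)
print(model.vocabularySize, model.kvHeads.count)
} catch {
print("failed to decode configuration: \(error)")
}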