Skip to content

Commit 3013e71

Browse files
authored
✨ feat: Add support for Xiaomi MiMo model (#306)
* ✨ feat: Add support for MiMo model
1 parent 45b2956 commit 3013e71

File tree

4 files changed

+282
-2
lines changed

4 files changed

+282
-2
lines changed

Libraries/MLXLLM/Documentation.docc/Documentation.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,5 @@ Example implementations of various Large Language Models (LLMs).
3030
- ``Qwen2Model``
3131
- ``Qwen3Model``
3232
- ``Starcoder2Model``
33+
- ``MiMoModel``
3334
- ``GLM4Model``
34-
35-

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
4343
"openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
4444
"internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
4545
"granite": create(GraniteConfiguration.self, GraniteModel.init),
46+
"mimo": create(MiMoConfiguration.self, MiMoModel.init),
4647
"glm4": create(GLM4Configuration.self, GLM4Model.init),
4748
]
4849
}
@@ -200,6 +201,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
200201
defaultPrompt: ""
201202
)
202203

204+
static public let mimo_7b_sft_4bit = ModelConfiguration(
205+
id: "mlx-community/MiMo-7B-SFT-4bit",
206+
defaultPrompt: "Why is the sky blue?"
207+
)
208+
203209
static public let glm4_9b_4bit = ModelConfiguration(
204210
id: "mlx-community/GLM-4-9B-0414-4bit",
205211
defaultPrompt: "Why is the sky blue?"
@@ -231,6 +237,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
231237
qwen3_4b_4bit,
232238
qwen3_8b_4bit,
233239
smolLM_135M_4bit,
240+
mimo_7b_sft_4bit,
234241
glm4_9b_4bit,
235242
]
236243
}

Libraries/MLXLLM/Models/MiMo.swift

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
//
2+
// MiMo.swift
3+
// LLM
4+
//
5+
// Created by John Mai on 2025/5/3.
6+
//
7+
8+
import Foundation
9+
import MLX
10+
import MLXFast
11+
import MLXLMCommon
12+
import MLXNN
13+
14+
/// Multi-head self-attention layer for MiMo with rotary position embeddings.
///
/// The query, key and value projections carry a bias term; the output projection does not.
private class Attention: Module {
    let args: MiMoConfiguration
    let scale: Float

    @ModuleInfo(key: "q_proj") var wq: Linear
    @ModuleInfo(key: "k_proj") var wk: Linear
    @ModuleInfo(key: "v_proj") var wv: Linear
    @ModuleInfo(key: "o_proj") var wo: Linear

    let rope: RoPE

    /// Builds the four projections and the rotary embedding described by `args`.
    ///
    /// The per-head width is `hiddenSize / attentionHeads` and the attention scale is that
    /// width raised to the power -0.5. A `ropeScaling` entry whose `"type"` is `"linear"`
    /// and which has a `"factor"` sets the RoPE scale to `1 / factor`; otherwise the scale is 1.
    public init(_ args: MiMoConfiguration) {
        self.args = args

        let modelDim = args.hiddenSize
        let queryHeads = args.attentionHeads
        let keyValueHeads = args.kvHeads
        let perHeadDim = modelDim / queryHeads

        self.scale = pow(Float(perHeadDim), -0.5)

        _wq.wrappedValue = Linear(modelDim, queryHeads * perHeadDim, bias: true)
        _wk.wrappedValue = Linear(modelDim, keyValueHeads * perHeadDim, bias: true)
        _wv.wrappedValue = Linear(modelDim, keyValueHeads * perHeadDim, bias: true)
        _wo.wrappedValue = Linear(queryHeads * perHeadDim, modelDim, bias: false)

        let ropeScale: Float
        if let scaling = args.ropeScaling,
            scaling["type"] == .string("linear"),
            let factor = scaling["factor"]
        {
            // A linear scaling entry without a numeric factor is treated as unrecoverable.
            guard let value = factor.asFloat() else {
                fatalError("ropeScaling.factor must be a float")
            }
            ropeScale = 1 / value
        } else {
            ropeScale = 1
        }

        self.rope = RoPE(
            dimensions: perHeadDim,
            traditional: args.ropeTraditional,
            base: args.ropeTheta,
            scale: ropeScale
        )
    }

    /// Runs attention over `x`.
    ///
    /// - Parameters:
    ///   - x: input whose first two dimensions are read as batch and sequence length.
    ///   - mask: optional mask forwarded to `MLXFast.scaledDotProductAttention`.
    ///   - cache: optional key/value cache; when present its `offset` positions the rotary
    ///     embedding and the cache is updated with the new keys and values.
    /// - Returns: the output projection applied to the re-merged attention heads.
    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
    ) -> MLXArray {
        let batch = x.dim(0)
        let length = x.dim(1)

        // Project, then split into heads: [B, L, H * D] -> [B, H, L, D].
        var queries = wq(x)
            .reshaped(batch, length, args.attentionHeads, -1)
            .transposed(0, 2, 1, 3)
        var keys = wk(x)
            .reshaped(batch, length, args.kvHeads, -1)
            .transposed(0, 2, 1, 3)
        var values = wv(x)
            .reshaped(batch, length, args.kvHeads, -1)
            .transposed(0, 2, 1, 3)

        if let cache {
            let offset = cache.offset
            queries = rope(queries, offset: offset)
            keys = rope(keys, offset: offset)
            (keys, values) = cache.update(keys: keys, values: values)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let attended = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask
        )

        // Merge the heads back: [B, H, L, D] -> [B, L, H * D].
        let merged = attended.transposed(0, 2, 1, 3).reshaped(batch, length, -1)

        return wo(merged)
    }
}
90+
91+
/// Gated feed-forward block computing `down(silu(gate(x)) * up(x))`.
///
/// All three projections are created without a bias term.
private class MLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gate: Linear
    @ModuleInfo(key: "down_proj") var down: Linear
    @ModuleInfo(key: "up_proj") var up: Linear

    /// - Parameters:
    ///   - dimensions: input and output width of the block.
    ///   - hiddenDimensions: width of the gate and up projections.
    public init(dimensions: Int, hiddenDimensions: Int) {
        _gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
        _down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
        _up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
    }

    /// Applies the gated projection to `x` and maps the result back with `down`.
    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        let gated = silu(gate(x))
        let expanded = up(x)
        return down(gated * expanded)
    }
}
106+
107+
/// One decoder layer: normalised self-attention followed by a normalised MLP, each added
/// back onto its input as a residual.
private class TransformerBlock: Module {
    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

    /// Creates the attention, MLP and both RMS norms from `args`.
    public init(_ args: MiMoConfiguration) {
        _attention.wrappedValue = Attention(args)
        self.mlp = MLP(
            dimensions: args.hiddenSize,
            hiddenDimensions: args.intermediateSize
        )
        _inputLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize,
            eps: args.rmsNormEps
        )
        _postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize,
            eps: args.rmsNormEps
        )
    }

    /// Applies the layer to `x`.
    ///
    /// - Parameters:
    ///   - x: hidden states entering the layer.
    ///   - mask: optional attention mask passed through to the attention sub-layer.
    ///   - cache: optional key/value cache passed through to the attention sub-layer.
    /// - Returns: `x` plus the attention output, plus the MLP output computed from that sum.
    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
    ) -> MLXArray {
        let attentionOut = attention(inputLayerNorm(x), mask: mask, cache: cache)
        let afterAttention = x + attentionOut

        let feedForwardOut = mlp(postAttentionLayerNorm(afterAttention))
        return afterAttention + feedForwardOut
    }
}
133+
134+
/// Token embedding, the stack of transformer blocks and the final RMS norm.
private class MiMoModelInner: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    fileprivate let layers: [TransformerBlock]
    let norm: RMSNorm

    /// Copied from the configuration; not read anywhere else in this class.
    let numNextnPredictLayers: Int

    /// Builds `args.hiddenLayers` transformer blocks around an embedding table of
    /// `args.vocabularySize` rows. Traps if the vocabulary size is not positive.
    public init(_ args: MiMoConfiguration) {
        precondition(args.vocabularySize > 0)

        _embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize,
            dimensions: args.hiddenSize
        )

        self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self.numNextnPredictLayers = args.numNextnPredictLayers
    }

    /// Embeds `inputs`, runs every layer in order and normalises the result.
    ///
    /// - Parameters:
    ///   - inputs: token ids to embed.
    ///   - cache: optional per-layer caches; entry `i` is handed to layer `i`.
    /// - Returns: the normalised hidden states of the last layer.
    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var hidden = embedTokens(inputs)

        // The same mask is shared by every layer.
        let mask = createAttentionMask(h: hidden, cache: cache)

        for index in layers.indices {
            hidden = layers[index](hidden, mask: mask, cache: cache?[index])
        }

        return norm(hidden)
    }
}
167+
168+
/// MiMo language model: the inner transformer plus the projection to vocabulary size.
///
/// When `tieWordEmbeddings` is set, no `lm_head` is created and the embedding table is
/// reused through `asLinear` instead.
public class MiMoModel: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int
    public let kvHeads: [Int]

    private let model: MiMoModelInner
    let configuration: MiMoConfiguration

    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    /// Creates the model described by `args`.
    public init(_ args: MiMoConfiguration) {
        self.configuration = args
        self.vocabularySize = args.vocabularySize
        // Every layer reports the same number of key/value heads.
        self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
        self.model = MiMoModelInner(args)

        if !args.tieWordEmbeddings {
            _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
        }
    }

    /// Runs the transformer on `inputs` and projects the result to vocabulary size.
    ///
    /// - Parameters:
    ///   - inputs: token ids.
    ///   - cache: optional per-layer key/value caches.
    /// - Returns: the output of `lm_head` when it exists, otherwise of the embedding table
    ///   applied as a linear layer.
    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        let hidden = model(inputs, cache: cache)

        guard let head = lmHead else {
            return model.embedTokens.asLinear(hidden)
        }
        return head(hidden)
    }

    /// Drops checkpoint entries that this implementation does not load.
    ///
    /// Removes precomputed rotary frequencies (`self_attn.rotary_emb.inv_freq`), everything
    /// under `model.mtp_layers.`, and, when the embeddings are tied, `lm_head.weight`.
    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        var kept = weights.filter { key, _ in
            !(key.contains("self_attn.rotary_emb.inv_freq")
                || key.hasPrefix("model.mtp_layers."))
        }

        if configuration.tieWordEmbeddings {
            kept.removeValue(forKey: "lm_head.weight")
        }

        return kept
    }
}
211+
212+
/// Hyper-parameters of a MiMo model, decoded from its JSON configuration.
///
/// Keys that may be absent fall back to these defaults: `max_position_embeddings` 32768,
/// `rope_theta` 10000, `rope_traditional` false, `tie_word_embeddings` false,
/// `num_nextn_predict_layers` 2. `rope_scaling` stays `nil` when missing.
public struct MiMoConfiguration: Codable, Sendable {
    var hiddenSize: Int
    var hiddenLayers: Int
    var intermediateSize: Int
    var attentionHeads: Int
    var rmsNormEps: Float
    var vocabularySize: Int
    var kvHeads: Int
    var maxPositionEmbeddings: Int
    var ropeTheta: Float
    var ropeTraditional: Bool
    var ropeScaling: [String: StringOrNumber]?
    var tieWordEmbeddings: Bool
    var numNextnPredictLayers: Int

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
        case attentionHeads = "num_attention_heads"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case kvHeads = "num_key_value_heads"
        case maxPositionEmbeddings = "max_position_embeddings"
        case ropeTheta = "rope_theta"
        case ropeTraditional = "rope_traditional"
        case ropeScaling = "rope_scaling"
        case tieWordEmbeddings = "tie_word_embeddings"
        case numNextnPredictLayers = "num_nextn_predict_layers"
    }

    /// Decodes the configuration, throwing if any required key is missing or mistyped.
    public init(from decoder: Decoder) throws {
        let values = try decoder.container(keyedBy: CodingKeys.self)

        // Required keys.
        hiddenSize = try values.decode(Int.self, forKey: .hiddenSize)
        hiddenLayers = try values.decode(Int.self, forKey: .hiddenLayers)
        intermediateSize = try values.decode(Int.self, forKey: .intermediateSize)
        attentionHeads = try values.decode(Int.self, forKey: .attentionHeads)
        rmsNormEps = try values.decode(Float.self, forKey: .rmsNormEps)
        vocabularySize = try values.decode(Int.self, forKey: .vocabularySize)
        kvHeads = try values.decode(Int.self, forKey: .kvHeads)

        // Optional keys with their fallbacks.
        let positions = try values.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings)
        maxPositionEmbeddings = positions ?? 32768

        let theta = try values.decodeIfPresent(Float.self, forKey: .ropeTheta)
        ropeTheta = theta ?? 10000.0

        let traditional = try values.decodeIfPresent(Bool.self, forKey: .ropeTraditional)
        ropeTraditional = traditional ?? false

        ropeScaling = try values.decodeIfPresent(
            [String: StringOrNumber].self, forKey: .ropeScaling)

        let tied = try values.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings)
        tieWordEmbeddings = tied ?? false

        let nextn = try values.decodeIfPresent(Int.self, forKey: .numNextnPredictLayers)
        numNextnPredictLayers = nextn ?? 2
    }
}
266+
267+
// MARK: - LoRA
268+
269+
extension MiMoModel: LoRAModel {
    /// Lists, for every transformer block, its attention module together with the keys of
    /// the projections (`q_proj` and `v_proj`) offered for LoRA adaptation.
    public func loraLinearLayers() -> LoRALinearLayers {
        let targetKeys = ["q_proj", "v_proj"]
        return model.layers.map { block in (block.attention, targetKeys) }
    }
}

Libraries/MLXLLM/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Currently supported model types are:
5656
- Qwen2
5757
- Qwen3
5858
- Starcoder2
59+
- MiMo
5960
- GLM4
6061

6162
See [llm-tool](../../Tools/llm-tool)

0 commit comments

Comments
 (0)