
Commit fad4ad0

Add Phi 3.5 MoE (#116)
* Add Phi 3.5 MoE
* make sure all models are registered; fix prompt generation
* make SwitchLinear match how Linear works re: quantization
* split switch layers into its own file (to better match the Python version)

Co-authored-by: David Koski <dkoski@apple.com>
1 parent 5a7a1a4 commit fad4ad0

File tree

7 files changed: +479 −27 lines

Libraries/LLM/Configuration.swift

Lines changed: 5 additions & 0 deletions
@@ -53,6 +53,11 @@ private class ModelTypeRegistry: @unchecked Sendable {
             Phi3Configuration.self, from: Data(contentsOf: url))
         return Phi3Model(configuration)
     },
+    "phimoe": { url in
+        let configuration = try JSONDecoder().decode(
+            PhiMoEConfiguration.self, from: Data(contentsOf: url))
+        return PhiMoEModel(configuration)
+    },
     "gemma": { url in
         let configuration = try JSONDecoder().decode(
             GemmaConfiguration.self, from: Data(contentsOf: url))
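
For context, the new "phimoe" entry follows the same pattern as the other registry closures: decode the model's config.json into the new PhiMoEConfiguration and build a PhiMoEModel from it. A minimal sketch of that path (not part of the commit; the local path is hypothetical):

import Foundation

do {
    // Hypothetical location of a downloaded config.json for the 4-bit PhiMoE snapshot.
    let url = URL(fileURLWithPath: "/tmp/Phi-3.5-MoE-instruct-4bit/config.json")
    let configuration = try JSONDecoder().decode(
        PhiMoEConfiguration.self, from: Data(contentsOf: url))
    let model = PhiMoEModel(configuration)
    _ = model
} catch {
    print("failed to load configuration: \(error)")
}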

Libraries/LLM/Models.swift

Lines changed: 19 additions & 7 deletions
@@ -134,6 +134,15 @@ extension ModelConfiguration {
        extraEOSTokens: ["<|end|>"]
    )

+    public static let phi3_5MoE = ModelConfiguration(
+        id: "mlx-community/Phi-3.5-MoE-instruct-4bit",
+        defaultPrompt: "What is the gravity on Mars and the moon?",
+        extraEOSTokens: ["<|end|>"]
+    ) {
+        prompt in
+        "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
+    }
+
    public static let gemma2bQuantized = ModelConfiguration(
        id: "mlx-community/quantized-gemma-2b-it",
        overrideTokenizer: "PreTrainedTokenizer",

@@ -202,19 +211,22 @@ extension ModelConfiguration {
        case .idle:
            bootstrapState = .bootstrapping
            register(configurations: [
+                codeLlama13b4bit,
+                gemma2bQuantized,
+                gemma_2_2b_it_4bit,
+                gemma_2_9b_it_4bit,
                llama3_1_8B_4bit,
                llama3_2_1B_4bit,
                llama3_2_3B_4bit,
-                mistralNeMo4bit,
-                smolLM_135M_4bit,
+                llama3_8B_4bit,
                mistral7B4bit,
-                codeLlama13b4bit,
-                phi4bit,
+                mistralNeMo4bit,
+                openelm270m4bit,
+                phi3_5MoE,
                phi3_5_4bit,
-                gemma2bQuantized,
-                gemma_2_9b_it_4bit,
+                phi4bit,
                qwen205b4bit,
-                openelm270m4bit,
+                smolLM_135M_4bit,
            ])
            bootstrapState = .bootstrapped
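
The trailing closure on phi3_5MoE is the per-prompt formatter (as with the other Phi configurations), wrapping the user text in the Phi chat template. Illustration only, not part of the commit:

// What the formatting closure produces for a user prompt.
let prompt = "What is the gravity on Mars and the moon?"
let wrapped = "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
print(wrapped)
// <|user|>
// What is the gravity on Mars and the moon?<|end|>
// <|assistant|>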

Libraries/LLM/Phi3.swift

Lines changed: 20 additions & 15 deletions
@@ -207,21 +207,25 @@ public class Phi3Model: Module, LLMModel, KVCacheDimensionProvider {
    }
}

-public struct Phi3Configuration: Codable, Sendable {
-    struct RopeScaling: Codable {
-        let longFactor: [Float]?
-        let shortFactor: [Float]?
-        let factor: Float?
-        let type: String?
-
-        enum CodingKeys: String, CodingKey {
-            case type
-            case factor
-            case longFactor = "long_factor"
-            case shortFactor = "short_factor"
-        }
+struct RopeScalingWithFactorArrays: Codable {
+    let longFactor: [Float]?
+    let shortFactor: [Float]?
+    let factor: Float?
+    let type: String?
+    let longMScale: Float?
+    let shortMScale: Float?
+
+    enum CodingKeys: String, CodingKey {
+        case type
+        case factor
+        case longFactor = "long_factor"
+        case shortFactor = "short_factor"
+        case longMScale = "long_mscale"
+        case shortMScale = "short_mscale"
    }
+}

+public struct Phi3Configuration: Codable, Sendable {
    var hiddenSize: Int
    var hiddenLayers: Int
    var intermediateSize: Int

@@ -231,7 +235,7 @@ public struct Phi3Configuration: Codable, Sendable {
    var kvHeads: Int
    var ropeTheta: Float = 10_000
    var ropeTraditional: Bool = false
-    var ropeScaling: RopeScaling?
+    var ropeScaling: RopeScalingWithFactorArrays?
    var maxPositionEmbeddings: Int
    var originalMaxPositionEmbeddings: Int

@@ -273,7 +277,8 @@ public struct Phi3Configuration: Codable, Sendable {
        ropeTraditional =
            try container.decodeIfPresent(
                Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? false
-        ropeScaling = try container.decodeIfPresent(RopeScaling.self, forKey: .ropeScaling)
+        ropeScaling = try container.decodeIfPresent(
+            RopeScalingWithFactorArrays.self, forKey: .ropeScaling)
        maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
        originalMaxPositionEmbeddings = try container.decode(
            Int.self, forKey: .originalMaxPositionEmbeddings)
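
The point of this change is that the rope-scaling struct, now with the optional long_mscale / short_mscale fields, is pulled out of Phi3Configuration so PhiMoE (below) can reuse it. A rough decoding sketch, not part of the commit, with a hypothetical rope_scaling block (the "longrope" type value and field values are made up for illustration):

import Foundation

// Hypothetical rope_scaling JSON; the snake_case keys map through the CodingKeys above.
let json = """
{"type": "longrope", "long_factor": [1.0, 1.1], "short_factor": [1.0], "long_mscale": 1.2, "short_mscale": 1.0}
""".data(using: .utf8)!

let scaling = try? JSONDecoder().decode(RopeScalingWithFactorArrays.self, from: json)
// scaling?.longMScale == 1.2; "factor" is absent, so scaling?.factor decodes to nil.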

Libraries/LLM/PhiMoE.swift

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
import Foundation
import MLX
import MLXFast
import MLXNN
import MLXRandom

// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phimoe.py

public struct PhiMoEConfiguration: Codable, Sendable {
    var modelType: String = "phimoe"
    var vocabularySize: Int = 32064
    var hiddenSize: Int = 4096
    var intermediateSize: Int = 6400
    var hiddenLayers: Int = 32
    var attentionHeads: Int = 32
    var kvHeads: Int = 8
    var maxPositionEmbeddings: Int = 131072
    var originalMaxPositionEmbeddings: Int = 4096
    var rmsNormEps: Float = 1e-6
    var ropeScaling: RopeScalingWithFactorArrays?
    var numLocalExperts: Int = 16
    var numExpertsPerToken: Int = 2
    var ropeTheta: Float = 10000.0

    enum CodingKeys: String, CodingKey {
        case modelType = "model_type"
        case vocabularySize = "vocab_size"
        case hiddenSize = "hidden_size"
        case intermediateSize = "intermediate_size"
        case hiddenLayers = "num_hidden_layers"
        case attentionHeads = "num_attention_heads"
        case kvHeads = "num_key_value_heads"
        case maxPositionEmbeddings = "max_position_embeddings"
        case originalMaxPositionEmbeddings = "original_max_position_embeddings"
        case rmsNormEps = "rms_norm_eps"
        case ropeScaling = "rope_scaling"
        case numLocalExperts = "num_local_experts"
        case numExpertsPerToken = "num_experts_per_tok"
        case ropeTheta = "rope_theta"
    }
}

private class Attention: Module {
    let args: PhiMoEConfiguration
    let scale: Float

    @ModuleInfo(key: "q_proj") var wq: Linear
    @ModuleInfo(key: "k_proj") var wk: Linear
    @ModuleInfo(key: "v_proj") var wv: Linear
    @ModuleInfo(key: "o_proj") var wo: Linear

    let rope: SuScaledRotaryEmbedding

    init(_ args: PhiMoEConfiguration) {
        self.args = args

        let dim = args.hiddenSize
        let heads = args.attentionHeads
        let kvHeads = args.kvHeads

        let headDim = args.hiddenSize / heads
        self.scale = pow(Float(headDim), -0.5)

        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true)
        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: true)

        self.rope = SuScaledRotaryEmbedding(
            dimensions: headDim,
            base: args.ropeTheta,
            maxPositionEmbeddings: args.maxPositionEmbeddings,
            originalMaxPositionEmbeddings: args.originalMaxPositionEmbeddings,
            longFactor: args.ropeScaling?.longFactor as? [Float] ?? [1.0],
            longMScale: args.ropeScaling?.longMScale as? Float
        )
    }

    func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?) -> MLXArray {
        let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2))

        let queries = wq(x)
        let keys = wk(x)
        let values = wv(x)

        // Prepare the queries, keys and values for the attention computation
        var q = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
        var k = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
        var v = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

        if let cache {
            q = rope(q, offset: cache.offset)
            k = rope(k, offset: cache.offset)
            (k, v) = cache.update(keys: k, values: v)
        } else {
            q = rope(q)
            k = rope(k)
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: q, keys: k, values: v, scale: scale, mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return wo(output)
    }
}

private class PhiMoESparseMoeBlock: Module {
    let hiddenDim: Int
    let ffnDim: Int
    let numExperts: Int
    let topK: Int

    @ModuleInfo(key: "gate") var gate: Linear
    @ModuleInfo(key: "switch_mlp") var switchMLP: SwitchGLU

    init(_ args: PhiMoEConfiguration) {
        self.hiddenDim = args.hiddenSize
        self.ffnDim = args.intermediateSize
        self.numExperts = args.numLocalExperts
        self.topK = args.numExpertsPerToken

        self._gate.wrappedValue = Linear(hiddenDim, numExperts, bias: false)
        self._switchMLP.wrappedValue = SwitchGLU(
            inputDims: hiddenDim, hiddenDims: ffnDim, numExperts: numExperts)
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        let gates = gate(x)

        let k = self.topK
        let inds = MLX.stopGradient(
            MLX.argPartition(
                -gates,
                kth: k - 1,
                axis: -1
            )[.ellipsis, ..<k])
        let scores = MLX.softmax(MLX.takeAlong(gates, inds, axis: -1), axis: -1, precise: true)

        let y = switchMLP(x, inds)
        return (y * scores[.ellipsis, .newAxis]).sum(axis: -2)
    }
}

private class PhiMoEDecoderLayer: Module {
    let hiddenSize: Int

    @ModuleInfo(key: "self_attn") var selfAttn: Attention
    @ModuleInfo(key: "block_sparse_moe") var blockSparseMoe: PhiMoESparseMoeBlock
    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: LayerNorm

    init(_ args: PhiMoEConfiguration) {
        self.hiddenSize = args.hiddenSize

        self._selfAttn.wrappedValue = Attention(args)
        self._blockSparseMoe.wrappedValue = PhiMoESparseMoeBlock(args)
        self._inputLayerNorm.wrappedValue = LayerNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self._postAttentionLayerNorm.wrappedValue = LayerNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?) -> MLXArray {
        var residual = x
        var hiddenStates = inputLayerNorm(x)
        hiddenStates = selfAttn(hiddenStates, mask: mask, cache: cache)
        hiddenStates = residual + hiddenStates

        residual = hiddenStates
        hiddenStates = postAttentionLayerNorm(hiddenStates)
        hiddenStates = blockSparseMoe(hiddenStates)
        hiddenStates = residual + hiddenStates

        return hiddenStates
    }
}

private class PhiMoEModelInner: Module {
    let args: PhiMoEConfiguration

    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
    let layers: [PhiMoEDecoderLayer]
    @ModuleInfo(key: "norm") var norm: LayerNorm

    init(_ args: PhiMoEConfiguration) {
        self.args = args

        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
        self.layers = (0 ..< args.hiddenLayers).map { _ in PhiMoEDecoderLayer(args) }
        self._norm.wrappedValue = LayerNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        var h = embedTokens(inputs)

        let mask = createAttentionMask(h: h, cache: cache)

        for (i, layer) in layers.enumerated() {
            h = layer(h, mask: mask, cache: cache?[i])
        }

        return norm(h)
    }
}

public class PhiMoEModel: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int
    public let kvHeads: [Int]
    public let headDim: IntOrPair

    fileprivate let model: PhiMoEModelInner
    @ModuleInfo(key: "lm_head") var lmHead: Linear

    public init(_ args: PhiMoEConfiguration) {
        self.vocabularySize = args.vocabularySize
        self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
        self.headDim = .init(args.hiddenSize / args.attentionHeads)
        self.model = PhiMoEModelInner(args)
        self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        let out = model(inputs, cache: cache)
        return lmHead(out)
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        var sanitizedWeights = weights
        if sanitizedWeights["model.layers.0.block_sparse_moe.experts.0.w1.weight"] == nil {
            return sanitizedWeights
        }

        for l in 0 ..< model.args.hiddenLayers {
            let prefix = "model.layers.\(l)"
            for (n, m) in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")] {
                for k in ["weight", "scales", "biases"] {
                    if sanitizedWeights["\(prefix).block_sparse_moe.experts.0.\(n).\(k)"] != nil {
                        let toJoin = (0 ..< model.args.numLocalExperts).map { e in
                            sanitizedWeights.removeValue(
                                forKey: "\(prefix).block_sparse_moe.experts.\(e).\(n).\(k)")!
                        }
                        sanitizedWeights["\(prefix).block_sparse_moe.switch_mlp.\(m).\(k)"] =
                            MLX.stacked(toJoin)
                    }
                }
            }
        }

        return sanitizedWeights
    }
}

// MARK: - LoRA

extension PhiMoEModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.selfAttn, ["q_proj", "v_proj"]) }
    }
}
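
PhiMoESparseMoeBlock routes each token to the top numExpertsPerToken experts: argPartition on the negated router logits selects the k largest, softmax renormalizes over just those, and the weighted expert outputs are summed. A rough sketch of that routing arithmetic on a single token with 4 experts and k = 2 (not part of the commit; the logits are made up):

import MLX

// Hypothetical router logits for one token over 4 experts.
let gates = MLXArray([0.1, 2.0, -0.5, 1.2] as [Float])
let k = 2

// Indices of the k largest logits (argPartition on -gates puts them in the first k slots).
let inds = MLX.argPartition(-gates, kth: k - 1, axis: -1)[.ellipsis, ..<k]
// Softmax over only the selected logits gives the mixing weights for those experts.
let scores = MLX.softmax(MLX.takeAlong(gates, inds, axis: -1), axis: -1, precise: true)

print(inds, scores)  // selects experts 1 and 3; scores renormalizes their logits (2.0 and 1.2)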
