
Commit bccfd7a

anishbasu and Anish Basu authored
Add Embedders/Encoders support. (#157)
* Add Embedders/Encoders support.
* Fixing spelling error in "Tropical".
* Running pre-commit hooks for the PR.
* Remove extra header docs from the Tokenizer.swift file.

Co-authored-by: Anish Basu <banish@apple.com>
1 parent a2ebad7 commit bccfd7a

File tree

11 files changed: +1591 -6 lines changed


Libraries/Embedders/Bert.swift

Lines changed: 334 additions & 0 deletions
@@ -0,0 +1,334 @@
// Copyright © 2024 Apple Inc.

import MLX
import MLXFast
import MLXNN

extension MLXArray {
    public static func arange(_ size: Int) -> MLXArray {
        return MLXArray(Array(0 ..< size))
    }
}

private class BertEmbedding: Module {

    let typeVocabularySize: Int
    @ModuleInfo(key: "word_embeddings") var wordEmbeddings: Embedding
    @ModuleInfo(key: "norm") var norm: LayerNorm
    @ModuleInfo(key: "token_type_embeddings") var tokenTypeEmbeddings: Embedding?
    @ModuleInfo(key: "position_embeddings") var positionEmbeddings: Embedding

    init(_ config: BertConfiguration) {
        typeVocabularySize = config.typeVocabularySize
        _wordEmbeddings.wrappedValue = Embedding(
            embeddingCount: config.vocabularySize, dimensions: config.embedDim)
        _norm.wrappedValue = LayerNorm(
            dimensions: config.embedDim, eps: config.layerNormEps)
        if config.typeVocabularySize > 0 {
            _tokenTypeEmbeddings.wrappedValue = Embedding(
                embeddingCount: config.typeVocabularySize,
                dimensions: config.embedDim)
        }
        _positionEmbeddings.wrappedValue = Embedding(
            embeddingCount: config.maxPositionEmbeddings,
            dimensions: config.embedDim)
    }

    func callAsFunction(
        _ inputIds: MLXArray,
        positionIds: MLXArray? = nil,
        tokenTypeIds: MLXArray? = nil
    ) -> MLXArray {
        let posIds = positionIds ?? broadcast(MLXArray.arange(inputIds.dim(1)), to: inputIds.shape)
        // `words` must be `var`: token-type embeddings may be added in place below.
        var words = wordEmbeddings(inputIds) + positionEmbeddings(posIds)
        if let tokenTypeIds, let tokenTypeEmbeddings {
            words += tokenTypeEmbeddings(tokenTypeIds)
        }
        return norm(words)
    }
}

private class TransformerBlock: Module {
    let attention: MultiHeadAttention
    @ModuleInfo(key: "ln1") var preLayerNorm: LayerNorm
    @ModuleInfo(key: "ln2") var postLayerNorm: LayerNorm
    @ModuleInfo(key: "linear1") var up: Linear
    @ModuleInfo(key: "linear2") var down: Linear

    init(_ config: BertConfiguration) {
        attention = MultiHeadAttention(
            dimensions: config.embedDim, numHeads: config.numHeads, bias: true)
        _preLayerNorm.wrappedValue = LayerNorm(
            dimensions: config.embedDim, eps: config.layerNormEps)
        _postLayerNorm.wrappedValue = LayerNorm(
            dimensions: config.embedDim, eps: config.layerNormEps)
        _up.wrappedValue = Linear(config.embedDim, config.interDim)
        _down.wrappedValue = Linear(config.interDim, config.embedDim)
    }

    func callAsFunction(_ inputs: MLXArray, mask: MLXArray? = nil) -> MLXArray {
        let attentionOut = attention(inputs, keys: inputs, values: inputs, mask: mask)
        let preNorm = preLayerNorm(inputs + attentionOut)

        let mlpOut = down(gelu(up(preNorm)))
        return postLayerNorm(mlpOut + preNorm)
    }
}

private class Encoder: Module {
    let layers: [TransformerBlock]
    init(_ config: BertConfiguration) {
        precondition(config.vocabularySize > 0)
        layers = (0 ..< config.numLayers).map { _ in TransformerBlock(config) }
    }
    func callAsFunction(_ inputs: MLXArray, attentionMask: MLXArray? = nil) -> MLXArray {
        var outputs = inputs
        for layer in layers {
            outputs = layer(outputs, mask: attentionMask)
        }
        return outputs
    }
}

private class LMHead: Module {
    @ModuleInfo(key: "dense") var dense: Linear
    @ModuleInfo(key: "ln") var layerNorm: LayerNorm
    @ModuleInfo(key: "decoder") var decoder: Linear

    init(_ config: BertConfiguration) {
        _dense.wrappedValue = Linear(
            config.embedDim, config.embedDim, bias: true)
        _layerNorm.wrappedValue = LayerNorm(
            dimensions: config.embedDim, eps: config.layerNormEps)
        _decoder.wrappedValue = Linear(
            config.embedDim, config.vocabularySize, bias: true)
    }
    func callAsFunction(_ inputs: MLXArray) -> MLXArray {
        return decoder(layerNorm(silu(dense(inputs))))
    }
}

public class BertModel: Module, EmbeddingModel {
    @ModuleInfo(key: "lm_head") fileprivate var lmHead: LMHead?
    @ModuleInfo(key: "embeddings") fileprivate var embedder: BertEmbedding
    let pooler: Linear?
    fileprivate let encoder: Encoder
    public var vocabularySize: Int

    public init(
        _ config: BertConfiguration, lmHead: Bool = false
    ) {
        precondition(config.vocabularySize > 0)
        vocabularySize = config.vocabularySize
        encoder = Encoder(config)
        _embedder.wrappedValue = BertEmbedding(config)

        if lmHead {
            _lmHead.wrappedValue = LMHead(config)
            self.pooler = nil
        } else {
            pooler = Linear(config.embedDim, config.embedDim)
            _lmHead.wrappedValue = nil
        }
    }

    public func callAsFunction(
        _ inputs: MLXArray, positionIds: MLXArray? = nil, tokenTypeIds: MLXArray? = nil,
        attentionMask: MLXArray? = nil
    ) -> EmbeddingModelOutput {
        var inp = inputs
        if inp.ndim == 1 {
            inp = inp.reshaped(1, -1)
        }
        var mask = attentionMask
        if mask != nil {
            // Convert a 0/1 padding mask to an additive log-space mask:
            // log(1) = 0 keeps a position, log(0) = -inf removes it in the softmax.
            mask = mask!.asType(embedder.wordEmbeddings.weight.dtype)
                .expandedDimensions(axes: [1, 2]).log()
        }
        let outputs = encoder(
            embedder(inp, positionIds: positionIds, tokenTypeIds: tokenTypeIds),
            attentionMask: mask)
        if let lmHead {
            return EmbeddingModelOutput(hiddenStates: lmHead(outputs), pooledOutput: nil)
        } else {
            return EmbeddingModelOutput(
                hiddenStates: outputs, pooledOutput: tanh(pooler!(outputs[0..., 0])))
        }
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        weights.reduce(into: [:]) { result, item in
            var key = item.key.replacingOccurrences(of: ".layer.", with: ".layers.")
            key = key.replacingOccurrences(of: ".self.key.", with: ".key_proj.")
            key = key.replacingOccurrences(of: ".self.query.", with: ".query_proj.")
            key = key.replacingOccurrences(of: ".self.value.", with: ".value_proj.")
            key = key.replacingOccurrences(
                of: ".attention.output.dense.", with: ".attention.out_proj.")
            key = key.replacingOccurrences(of: ".attention.output.LayerNorm.", with: ".ln1.")
            key = key.replacingOccurrences(of: ".output.LayerNorm.", with: ".ln2.")
            key = key.replacingOccurrences(of: ".intermediate.dense.", with: ".linear1.")
            key = key.replacingOccurrences(of: ".output.dense.", with: ".linear2.")
            key = key.replacingOccurrences(of: ".LayerNorm.", with: ".norm.")
            key = key.replacingOccurrences(of: "pooler.dense.", with: "pooler.")
            key = key.replacingOccurrences(
                of: "cls.predictions.transform.dense.", with: "lm_head.dense.")
            key = key.replacingOccurrences(
                of: "cls.predictions.transform.LayerNorm.", with: "lm_head.ln.")
            key = key.replacingOccurrences(
                of: "cls.predictions.decoder", with: "lm_head.decoder")
            key = key.replacingOccurrences(
                of: "cls.predictions.transform.norm.weight", with: "lm_head.ln.weight")
            key = key.replacingOccurrences(
                of: "cls.predictions.transform.norm.bias", with: "lm_head.ln.bias")
            key = key.replacingOccurrences(of: "cls.predictions.bias", with: "lm_head.decoder.bias")
            key = key.replacingOccurrences(of: "bert.", with: "")
            result[key] = item.value
        }.filter { key, _ in key != "embeddings.position_ids" }
    }
}
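
`sanitize` rewrites Hugging Face BERT checkpoint keys into the module keys declared above via `@ModuleInfo`. A minimal sketch of what the remapping does to one key; the tiny integer array stands in for a real weight tensor, and `config` is assumed to be a decoded `BertConfiguration` (see below):

    // Illustrative only: trace one Hugging Face checkpoint key through sanitize.
    let model = BertModel(config)
    let remapped = model.sanitize(weights: [
        "bert.encoder.layer.0.attention.self.query.weight": MLXArray(Array(0 ..< 4))
    ])
    // ".layer." -> ".layers.", ".self.query." -> ".query_proj.", "bert." stripped:
    // remapped.keys == ["encoder.layers.0.attention.query_proj.weight"]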

public class DistilBertModel: BertModel {
    public override func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        weights.reduce(into: [:]) { result, item in
            var key = item.key.replacingOccurrences(of: ".layer.", with: ".layers.")
            key = key.replacingOccurrences(of: "transformer.", with: "encoder.")
            key = key.replacingOccurrences(of: "embeddings.LayerNorm", with: "embeddings.norm")
            key = key.replacingOccurrences(of: ".attention.q_lin.", with: ".attention.query_proj.")
            key = key.replacingOccurrences(of: ".attention.k_lin.", with: ".attention.key_proj.")
            key = key.replacingOccurrences(of: ".attention.v_lin.", with: ".attention.value_proj.")
            key = key.replacingOccurrences(of: ".attention.out_lin.", with: ".attention.out_proj.")
            key = key.replacingOccurrences(of: ".sa_layer_norm.", with: ".ln1.")
            key = key.replacingOccurrences(of: ".ffn.lin1.", with: ".linear1.")
            key = key.replacingOccurrences(of: ".ffn.lin2.", with: ".linear2.")
            key = key.replacingOccurrences(of: ".output_layer_norm.", with: ".ln2.")
            key = key.replacingOccurrences(of: "vocab_transform", with: "lm_head.dense")
            key = key.replacingOccurrences(of: "vocab_layer_norm", with: "lm_head.ln")
            key = key.replacingOccurrences(of: "vocab_projector", with: "lm_head.decoder")
            key = key.replacingOccurrences(of: "distilbert.", with: "")
            result[key] = item.value
        }.filter { key, _ in key != "embeddings.position_ids" }
    }
}

public struct BertConfiguration: Decodable, Sendable {
    var layerNormEps: Float = 1e-12
    var maxTrainedPositions: Int = 2048
    var embedDim: Int = 768
    var numHeads: Int = 12
    var interDim: Int = 3072
    var numLayers: Int = 12
    var typeVocabularySize: Int = 2
    var vocabularySize: Int = 30528
    var maxPositionEmbeddings: Int = 0
    var modelType: String

    enum CodingKeys: String, CodingKey {
        case layerNormEps = "layer_norm_eps"
        case maxTrainedPositions = "max_trained_positions"
        case vocabularySize = "vocab_size"
        case maxPositionEmbeddings = "max_position_embeddings"
        case modelType = "model_type"
    }

    enum BertCodingKeys: String, CodingKey {
        case embedDim = "hidden_size"
        case numHeads = "num_attention_heads"
        case interDim = "intermediate_size"
        case numLayers = "num_hidden_layers"
        case typeVocabularySize = "type_vocab_size"
    }

    enum DistilBertCodingKeys: String, CodingKey {
        case embedDim = "dim"
        case numLayers = "n_layers"
        case numHeads = "n_heads"
        case interDim = "hidden_dim"
    }

    public init(from decoder: Decoder) throws {
        let container: KeyedDecodingContainer<CodingKeys> =
            try decoder.container(keyedBy: CodingKeys.self)
        layerNormEps =
            try container.decodeIfPresent(
                Float.self, forKey: CodingKeys.layerNormEps.self) ?? 1e-12
        maxTrainedPositions =
            try container.decodeIfPresent(
                Int.self, forKey: CodingKeys.maxTrainedPositions.self) ?? 2048
        vocabularySize =
            try container.decodeIfPresent(
                Int.self, forKey: CodingKeys.vocabularySize.self) ?? 30528
        maxPositionEmbeddings =
            try container.decodeIfPresent(
                Int.self, forKey: CodingKeys.maxPositionEmbeddings.self) ?? 0
        modelType = try container.decode(String.self, forKey: CodingKeys.modelType.self)

        if modelType == "distilbert" {
            let distilBertConfig: KeyedDecodingContainer<DistilBertCodingKeys> =
                try decoder.container(keyedBy: DistilBertCodingKeys.self)
            embedDim =
                try distilBertConfig.decodeIfPresent(
                    Int.self, forKey: DistilBertCodingKeys.embedDim.self) ?? 768
            numHeads =
                try distilBertConfig.decodeIfPresent(
                    Int.self, forKey: DistilBertCodingKeys.numHeads.self) ?? 12
            interDim =
                try distilBertConfig.decodeIfPresent(
                    Int.self, forKey: DistilBertCodingKeys.interDim.self) ?? 3072
            numLayers =
                try distilBertConfig.decodeIfPresent(
                    Int.self, forKey: DistilBertCodingKeys.numLayers.self) ?? 12
            // DistilBERT has no token-type (segment) embeddings.
            typeVocabularySize = 0
        } else {
            let bertConfig: KeyedDecodingContainer<BertCodingKeys> = try decoder.container(
                keyedBy: BertCodingKeys.self)

            embedDim =
                try bertConfig.decodeIfPresent(
                    Int.self, forKey: BertCodingKeys.embedDim.self) ?? 768
            numHeads =
                try bertConfig.decodeIfPresent(
                    Int.self, forKey: BertCodingKeys.numHeads.self) ?? 12
            interDim =
                try bertConfig.decodeIfPresent(
                    Int.self, forKey: BertCodingKeys.interDim.self) ?? 3072
            numLayers =
                try bertConfig.decodeIfPresent(
                    Int.self, forKey: BertCodingKeys.numLayers.self) ?? 12
            typeVocabularySize =
                try bertConfig.decodeIfPresent(
                    Int.self, forKey: BertCodingKeys.typeVocabularySize.self) ?? 2
        }
    }
}
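
As a usage sketch (not part of the diff): `BertConfiguration` decodes directly from a Hugging Face-style config.json, after which the model can run a forward pass. Key names follow the `CodingKeys` above; the numeric values and the eight token ids are illustrative, and a real model would load checkpoint weights through `sanitize` rather than run at random initialization:

    import Foundation  // for JSONDecoder

    // Hypothetical config.json for a small BERT encoder (illustrative values).
    let json = """
        {
            "model_type": "bert",
            "hidden_size": 384,
            "num_attention_heads": 6,
            "intermediate_size": 1536,
            "num_hidden_layers": 6,
            "vocab_size": 30522,
            "max_position_embeddings": 512
        }
        """
    let config = try! JSONDecoder().decode(BertConfiguration.self, from: Data(json.utf8))

    let model = BertModel(config)                  // lmHead defaults to false
    let output = model(MLXArray(Array(0 ..< 8)))   // 1-D ids are reshaped to [1, 8]
    // With lmHead == false, pooledOutput is tanh(pooler(hidden[:, 0])).
    print(output.pooledOutput?.shape ?? [])        // [1, 384]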

0 commit comments
