Port of Ernie4 5 #348
Open
smdesai wants to merge 3 commits into ml-explore:main from smdesai:main
@@ -0,0 +1,234 @@
//
// Ernie4_5.swift
// mlx-swift-examples
//
// Created by Sachin Desai on 7/3/25.
//

import Foundation
import MLX
import MLXLMCommon
import MLXNN

// Port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/ernie4_5.py

public struct Ernie45Configuration: Codable, Sendable {
    var hiddenSize: Int
    var intermediateSize: Int
    var maxPositionEmbeddings: Int
    var numAttentionHeads: Int
    var numKeyValueHeads: Int
    var headDim: Int?
    var numHiddenLayers: Int
    var rmsNormEps: Float
    var vocabularySize: Int
    var ropeTheta: Float
    var useBias: Bool
    var tieWordEmbeddings: Bool

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case intermediateSize = "intermediate_size"
        case maxPositionEmbeddings = "max_position_embeddings"
        case numAttentionHeads = "num_attention_heads"
        case numKeyValueHeads = "num_key_value_heads"
        case headDim = "head_dim"
        case numHiddenLayers = "num_hidden_layers"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case ropeTheta = "rope_theta"
        case useBias = "use_bias"
        case tieWordEmbeddings = "tie_word_embeddings"
    }

    public init(from decoder: Decoder) throws {
        let container: KeyedDecodingContainer<Ernie45Configuration.CodingKeys> =
            try decoder.container(keyedBy: Ernie45Configuration.CodingKeys.self)

        self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
        self.maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
        self.numAttentionHeads = try container.decode(Int.self, forKey: .numAttentionHeads)
        self.numKeyValueHeads = try container.decode(Int.self, forKey: .numKeyValueHeads)
        self.headDim = try container.decodeIfPresent(Int.self, forKey: .headDim)
        self.numHiddenLayers = try container.decode(Int.self, forKey: .numHiddenLayers)
        self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
        self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
        self.ropeTheta = try container.decode(Float.self, forKey: .ropeTheta)
        self.useBias = try container.decodeIfPresent(Bool.self, forKey: .useBias) ?? false
        self.tieWordEmbeddings = try container.decode(Bool.self, forKey: .tieWordEmbeddings)
    }
}

private class Attention: Module {
    let nHeads: Int
    let nKVHeads: Int
    let headDim: Int
    let scale: Float

    @ModuleInfo(key: "q_proj") var qProj: Linear
    @ModuleInfo(key: "k_proj") var kProj: Linear
    @ModuleInfo(key: "v_proj") var vProj: Linear
    @ModuleInfo(key: "o_proj") var oProj: Linear

    let rope: RoPE

    public init(_ args: Ernie45Configuration) {
        let dim = args.hiddenSize
        self.nHeads = args.numAttentionHeads
        self.nKVHeads = args.numKeyValueHeads
        self.headDim = args.headDim ?? (dim / args.numAttentionHeads)
        self.scale = pow(Float(headDim), -0.5)

        self._qProj.wrappedValue = Linear(dim, nHeads * headDim, bias: args.useBias)
        self._kProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: args.useBias)
        self._vProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: args.useBias)
        self._oProj.wrappedValue = Linear(nHeads * headDim, dim, bias: args.useBias)

        self.rope = RoPE(
            dimensions: headDim,
            traditional: true,
            base: args.ropeTheta
        )
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        let (B, L) = (x.dim(0), x.dim(1))

        var queries = qProj(x)
        var keys = kProj(x)
        var values = vProj(x)

        queries = queries.reshaped(B, L, nHeads, -1).transposed(0, 2, 1, 3)
        keys = keys.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)

        if let cache {
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = attentionWithCacheUpdate(
            queries: queries,
            keys: keys,
            values: values,
            cache: cache,
            scale: scale,
            mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return oProj(output)
    }
}

private class MLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gateProj: Linear
    @ModuleInfo(key: "down_proj") var downProj: Linear
    @ModuleInfo(key: "up_proj") var upProj: Linear

    public init(dim: Int, hiddenDim: Int, useBias: Bool = false) {
        self._gateProj.wrappedValue = Linear(dim, hiddenDim, bias: useBias)
        self._downProj.wrappedValue = Linear(hiddenDim, dim, bias: useBias)
        self._upProj.wrappedValue = Linear(dim, hiddenDim, bias: useBias)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        downProj(silu(gateProj(x)) * upProj(x))
    }
}

private class DecoderLayer: Module {
    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "input_layernorm") var inputLayernorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayernorm: RMSNorm

    public init(_ args: Ernie45Configuration) {
        self._attention.wrappedValue = Attention(args)
        self.mlp = MLP(
            dim: args.hiddenSize, hiddenDim: args.intermediateSize, useBias: args.useBias)
        self._inputLayernorm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self._postAttentionLayernorm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        var r = attention(inputLayernorm(x), mask: mask, cache: cache)
        let h = x + r
        r = mlp(postAttentionLayernorm(h))
        return h + r
    }
}

private class Ernie45ModelInner: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
    let layers: [DecoderLayer]
    let norm: RMSNorm

    public init(_ args: Ernie45Configuration) {
        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize
        )
        self.layers = (0 ..< args.numHiddenLayers).map { _ in
            DecoderLayer(args)
        }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var h = embedTokens(inputs)

        let mask = createAttentionMask(h: h, cache: cache)

        for (i, layer) in layers.enumerated() {
            h = layer(h, mask: mask, cache: cache?[i])
        }

        return norm(h)
    }
}

public class Ernie45Model: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int
    public let kvHeads: [Int]

    private let model: Ernie45ModelInner
    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    public init(_ args: Ernie45Configuration) {
        self.vocabularySize = args.vocabularySize
        self.kvHeads = Array(repeating: args.numKeyValueHeads, count: args.numHiddenLayers)
        self.model = Ernie45ModelInner(args)

        if !args.tieWordEmbeddings {
            self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
        }
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        let out = model(inputs, cache: cache)

        if let lmHead {
            return lmHead(out)
        } else {
            return model.embedTokens.asLinear(out)
        }
    }
}

// MARK: - LoRA

extension Ernie45Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}
It is recommended not to modify this here.
#302 (comment)
@johnmai-dev Thanks for the heads up on the recommendation; I've reverted it. As for generating tokenizer.json, this is the Python script I used. I had downloaded the model previously, so I used the last example.
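The linked script itself is not quoted in this thread. As a rough, hedged sketch of that kind of conversion, assuming the model checkpoint is already downloaded locally (the model_dir path is a placeholder) and that the installed transformers version can build a fast tokenizer for ERNIE 4.5, it might look like this:

# Hedged sketch, not the exact script referenced above.
# Assumes the model is already downloaded locally (model_dir is a placeholder)
# and that the installed transformers version can build a fast tokenizer for ERNIE 4.5;
# some versions may also require trust_remote_code=True.
from transformers import AutoTokenizer

model_dir = "./ERNIE-4.5-0.3B-PT"  # placeholder path to the downloaded model

# Only the fast (Rust-backed) tokenizer carries the single-file tokenizer.json representation.
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True)

if tokenizer.is_fast:
    # save_pretrained writes tokenizer.json alongside tokenizer_config.json and friends
    tokenizer.save_pretrained(model_dir)
else:
    print("Only a slow tokenizer was built, so tokenizer.json will not be emitted.")

Whether the fast tokenizer can be built at all is exactly what the rest of this thread ends up debugging.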
Thank you very much! @smdesai
Hey @johnmai-dev, I see the ERNIE model in mlx-community was added by you. Any chance you can add tokenizer.json to the models there? I can then change LLMModelFactory to reference the model in mlx-community.
I used the Python script you provided, but it didn't generate tokenizer.json. Maybe I need to configure something else?
[screenshot attached]
Sorry, try this: https://colab.research.google.com/drive/1B9v_838cTn0KavQ26uWuFcRKJ7jrjVQb?usp=sharing
Where can I download the ERNIE-4.5-0.3B-PT-bf16 model used in this notebook?
The files are identical to the ones here: https://huggingface.co/mlx-community/ERNIE-4.5-0.3B-PT-bf16
Unfortunately, it still cannot be generated.
I have added huggingface-cli download and pip install commands to your notebook. Can you try running it with the notebook I provided? Let me see if there will be any difference in the results when you run it with my notebook.
https://colab.research.google.com/drive/1fAHK6EL8JYsHDo5duJr5llI1YeeK974t?usp=sharing
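The download and install commands added to the notebook are not quoted in this thread. As a rough equivalent using the huggingface_hub Python API, with the repo id from the mlx-community link above and a placeholder local directory, the setup step might look like this:

# Hedged sketch of the setup step; the notebook's actual commands are not shown here.
# Assumes huggingface_hub (and transformers) have been installed, e.g. via pip install.
from huggingface_hub import snapshot_download

# Repo id taken from the Hugging Face link shared above; local_dir is a placeholder.
local_dir = snapshot_download(
    repo_id="mlx-community/ERNIE-4.5-0.3B-PT-bf16",
    local_dir="./ERNIE-4.5-0.3B-PT-bf16",
)
print("Model files downloaded to:", local_dir)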
OK, I have no idea what's going on here. I tried your notebook and I get the same error as you. I also tried the same changes in my notebook and get the same error (not surprising). So the only thing that works is: