
Port of Ernie4_5 #348


Open · wants to merge 3 commits into main
4 changes: 2 additions & 2 deletions Applications/LLMEval/ContentView.swift
@@ -176,7 +176,7 @@ class LLMEvaluator

    /// This controls which model loads. `qwen2_5_1_5b` is one of the smaller ones, so this will fit on
    /// more devices.
-   let modelConfiguration = LLMRegistry.qwen3_1_7b_4bit
+   let modelConfiguration = LLMRegistry.ernie4503BPTbf16
Contributor

It is recommended not to modify this line; see #302 (comment).

Contributor Author

@johnmai-dev Thanks for the heads up on the recommendation; I've reverted it. As for generating tokenizer.json, this is the Python script I used. I had downloaded the model beforehand, so I used the last example.

from transformers import AutoTokenizer
import os

def convert_tokenizer_model_to_json(model_path, output_path=None):
    """
    Convert a tokenizer.model file to tokenizer.json format.

    Args:
        model_path: Path to the tokenizer.model file or directory containing it
        output_path: Optional output path for tokenizer.json (defaults to same directory)
    """
    # Handle both file and directory paths
    if os.path.isdir(model_path):
        tokenizer_model_path = os.path.join(model_path, "tokenizer.model")
    else:
        tokenizer_model_path = model_path
        model_path = os.path.dirname(model_path)

    if not os.path.exists(tokenizer_model_path):
        raise FileNotFoundError(f"tokenizer.model not found at {tokenizer_model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    if output_path is None:
        output_path = model_path

    tokenizer.save_pretrained(output_path)

    tokenizer_json_path = os.path.join(output_path, "tokenizer.json")
    if os.path.exists(tokenizer_json_path):
        print(f"Successfully created tokenizer.json at {tokenizer_json_path}")
    else:
        print("Warning: tokenizer.json was not created. The tokenizer might not support this format.")

    return tokenizer_json_path

# Example usage
if __name__ == "__main__":
    # Example: Convert a tokenizer.model file
    # convert_tokenizer_model_to_json("/path/to/tokenizer.model")

    # Example: Convert from a directory containing tokenizer.model
    # convert_tokenizer_model_to_json("/path/to/model/directory")
    pass  # placeholder so the guarded block is valid Python

Contributor

Thank you very much! @smdesai

Contributor Author

Hey @johnmai-dev, I see the ERNIE model in mlx-community was added by you. Any chance you can add tokenizer.json to the models there? I can then change LLMModelFactory to reference the model in mlx-community.

Contributor

I used the Python script you provided, but it didn't generate tokenizer.json. Maybe I need to configure something else?
[Screenshot: Finder, 2025-07-06 13:58]

Contributor

Where can I download ERNIE-4.5-0.3B-PT-bf16 from this notebook?

Contributor Author

The files are identical to the ones here: https://huggingface.co/mlx-community/ERNIE-4.5-0.3B-PT-bf16
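
If it helps, here is a minimal sketch of pulling those files to a local directory (this assumes huggingface_hub is installed; the target directory name is just an example):

from huggingface_hub import snapshot_download

# Download the repo contents to a local folder (example path);
# snapshot_download returns the path to the local snapshot.
local_dir = snapshot_download(
    repo_id="mlx-community/ERNIE-4.5-0.3B-PT-bf16",
    local_dir="ERNIE-4.5-0.3B-PT-bf16",
)
print(local_dir)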

Contributor

Unfortunately, it still cannot be generated.
I have added huggingface-cli download and pip install commands in your notebook.

Can you try running it with the notebook I provided? Let's see whether the results differ on your end.
https://colab.research.google.com/drive/1fAHK6EL8JYsHDo5duJr5llI1YeeK974t?usp=sharing

[Screenshot: Google Chrome, 2025-07-09 00:25]

Contributor Author

OK, I have no idea what's going on here. I tried your notebook and got the same error as you. I also tried the same changes in my notebook and got the same error (not surprising). So the only things that work are the following (a sketch of the first option follows the list):

  • downloading the model files via Hugging Face to a local directory and converting them there
  • uploading the model files to Colab and performing the conversion there
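
For reference, a sketch of that first working path, reusing convert_tokenizer_model_to_json from the script above (the directory name is just an example):

import os

# Assumes the repo was already downloaded locally, e.g. with
# snapshot_download or huggingface-cli download.
model_dir = "ERNIE-4.5-0.3B-PT-bf16"
assert os.path.exists(os.path.join(model_dir, "tokenizer.model"))

# On success this writes tokenizer.json into the same directory.
convert_tokenizer_model_to_json(model_dir)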


    /// parameters controlling the output
    let generateParameters = GenerateParameters(maxTokens: 240, temperature: 0.6)
@@ -232,7 +232,7 @@ class LLMEvaluator
    let timeTool = Tool<EmptyInput, TimeOutput>(
        name: "get_time",
        description: "Get the current time",
-       parameters: [],
+       parameters: []
    ) { _ in
        TimeOutput(time: Date.now.formatted())
    }
7 changes: 7 additions & 0 deletions Libraries/MLXLLM/LLMModelFactory.swift
@@ -49,6 +49,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
"glm4": create(GLM4Configuration.self, GLM4Model.init),
"acereason": create(Qwen2Configuration.self, Qwen2Model.init),
"bitnet": create(BitnetConfiguration.self, BitnetModel.init),
"ernie4_5": create(Ernie45Configuration.self, Ernie45Model.init),
]
}

@@ -231,6 +232,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
        defaultPrompt: "Why is the sky blue?"
    )

+   static public let ernie4503BPTbf16 = ModelConfiguration(
+       id: "smdesai/ERNIE-4.5-0.3B-PT-bf16",
+       defaultPrompt: "Why is the sky blue?"
+   )

    private static func all() -> [ModelConfiguration] {
        [
            codeLlama13b4bit,
@@ -263,6 +269,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
            glm4_9b_4bit,
            acereason_7b_4bit,
            bitnet_b1_58_2b_4t_4bit,
+           ernie4503BPTbf16
        ]
    }

234 changes: 234 additions & 0 deletions Libraries/MLXLLM/Models/Ernie4_5.swift
@@ -0,0 +1,234 @@
//
// Ernie4_5.swift
// mlx-swift-examples
//
// Created by Sachin Desai on 7/3/25.
//

import Foundation
import MLX
import MLXLMCommon
import MLXNN

// Port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/ernie4_5.py

public struct Ernie45Configuration: Codable, Sendable {
    var hiddenSize: Int
    var intermediateSize: Int
    var maxPositionEmbeddings: Int
    var numAttentionHeads: Int
    var numKeyValueHeads: Int
    var headDim: Int?
    var numHiddenLayers: Int
    var rmsNormEps: Float
    var vocabularySize: Int
    var ropeTheta: Float
    var useBias: Bool
    var tieWordEmbeddings: Bool

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case intermediateSize = "intermediate_size"
        case maxPositionEmbeddings = "max_position_embeddings"
        case numAttentionHeads = "num_attention_heads"
        case numKeyValueHeads = "num_key_value_heads"
        case headDim = "head_dim"
        case numHiddenLayers = "num_hidden_layers"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case ropeTheta = "rope_theta"
        case useBias = "use_bias"
        case tieWordEmbeddings = "tie_word_embeddings"
    }

    public init(from decoder: Decoder) throws {
        let container: KeyedDecodingContainer<Ernie45Configuration.CodingKeys> =
            try decoder.container(keyedBy: Ernie45Configuration.CodingKeys.self)

        self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
        self.maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
        self.numAttentionHeads = try container.decode(Int.self, forKey: .numAttentionHeads)
        self.numKeyValueHeads = try container.decode(Int.self, forKey: .numKeyValueHeads)
        // head_dim is optional in the config, so decode it only if present
        self.headDim = try container.decodeIfPresent(Int.self, forKey: .headDim)
        self.numHiddenLayers = try container.decode(Int.self, forKey: .numHiddenLayers)
        self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
        self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
        self.ropeTheta = try container.decode(Float.self, forKey: .ropeTheta)
        // decode(_:forKey:) never returns nil, so `?? false` was dead code;
        // decodeIfPresent allows the key to be absent and defaults to false
        self.useBias = try container.decodeIfPresent(Bool.self, forKey: .useBias) ?? false
        self.tieWordEmbeddings = try container.decode(Bool.self, forKey: .tieWordEmbeddings)
    }
}

private class Attention: Module {
    let nHeads: Int
    let nKVHeads: Int
    let headDim: Int
    let scale: Float

    @ModuleInfo(key: "q_proj") var qProj: Linear
    @ModuleInfo(key: "k_proj") var kProj: Linear
    @ModuleInfo(key: "v_proj") var vProj: Linear
    @ModuleInfo(key: "o_proj") var oProj: Linear

    let rope: RoPE

    public init(_ args: Ernie45Configuration) {
        let dim = args.hiddenSize
        self.nHeads = args.numAttentionHeads
        self.nKVHeads = args.numKeyValueHeads
        self.headDim = args.headDim ?? (dim / args.numAttentionHeads)
        self.scale = pow(Float(headDim), -0.5)

        self._qProj.wrappedValue = Linear(dim, nHeads * headDim, bias: args.useBias)
        self._kProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: args.useBias)
        self._vProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: args.useBias)
        self._oProj.wrappedValue = Linear(nHeads * headDim, dim, bias: args.useBias)

        self.rope = RoPE(
            dimensions: headDim,
            traditional: true,
            base: args.ropeTheta
        )
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        let (B, L) = (x.dim(0), x.dim(1))

        var queries = qProj(x)
        var keys = kProj(x)
        var values = vProj(x)

        queries = queries.reshaped(B, L, nHeads, -1).transposed(0, 2, 1, 3)
        keys = keys.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)

        if let cache {
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = attentionWithCacheUpdate(
            queries: queries,
            keys: keys,
            values: values,
            cache: cache,
            scale: scale,
            mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return oProj(output)
    }
}

private class MLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gateProj: Linear
    @ModuleInfo(key: "down_proj") var downProj: Linear
    @ModuleInfo(key: "up_proj") var upProj: Linear

    public init(dim: Int, hiddenDim: Int, useBias: Bool = false) {
        self._gateProj.wrappedValue = Linear(dim, hiddenDim, bias: useBias)
        self._downProj.wrappedValue = Linear(hiddenDim, dim, bias: useBias)
        self._upProj.wrappedValue = Linear(dim, hiddenDim, bias: useBias)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        downProj(silu(gateProj(x)) * upProj(x))
    }
}

private class DecoderLayer: Module {
    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "input_layernorm") var inputLayernorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayernorm: RMSNorm

    public init(_ args: Ernie45Configuration) {
        self._attention.wrappedValue = Attention(args)
        self.mlp = MLP(
            dim: args.hiddenSize, hiddenDim: args.intermediateSize, useBias: args.useBias)
        self._inputLayernorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self._postAttentionLayernorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        var r = attention(inputLayernorm(x), mask: mask, cache: cache)
        let h = x + r
        r = mlp(postAttentionLayernorm(h))
        return h + r
    }
}

private class Ernie45ModelInner: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
    let layers: [DecoderLayer]
    let norm: RMSNorm

    public init(_ args: Ernie45Configuration) {
        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize
        )
        self.layers = (0 ..< args.numHiddenLayers).map { _ in
            DecoderLayer(args)
        }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var h = embedTokens(inputs)

        let mask = createAttentionMask(h: h, cache: cache)

        for (i, layer) in layers.enumerated() {
            h = layer(h, mask: mask, cache: cache?[i])
        }

        return norm(h)
    }
}

public class Ernie45Model: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int
    public let kvHeads: [Int]

    private let model: Ernie45ModelInner
    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    public init(_ args: Ernie45Configuration) {
        self.vocabularySize = args.vocabularySize
        self.kvHeads = Array(repeating: args.numKeyValueHeads, count: args.numHiddenLayers)
        self.model = Ernie45ModelInner(args)

        if !args.tieWordEmbeddings {
            self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
        }
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        let out = model(inputs, cache: cache)

        if let lmHead {
            return lmHead(out)
        } else {
            return model.embedTokens.asLinear(out)
        }
    }
}

// MARK: - LoRA

extension Ernie45Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}