Commit cb887f1

pwilkin and CISC authored
model: add Ernie 4.5 MoE support (#14658)
* Add Ernie4.5 MoE
* Fix Flake errors.
* Properly encode/decode MoE layer step
* Correct tensor mappings (.weight)
* Pass and read n_ff_exp
* n_ff_shexp calculation and further minor changes
* Rope fixes.
* .gitignore fix
* Add unit32 cast for Linux builds
* Apply suggestions from code review
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Further fixes from code review
* Fix trailing whitespace
* Reenable missing experts error
* Code style from code review
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Fix non-MoE regression
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent d6fb3f6 commit cb887f1

File tree: 7 files changed (+373, -26 lines)


convert_hf_to_gguf.py

Lines changed: 88 additions & 1 deletion
@@ -2861,7 +2861,8 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         num_heads = self.hparams["num_attention_heads"]
         num_kv_heads = self.hparams["num_key_value_heads"]
-        head_dim = self.hparams["head_dim"]
+        if (head_dim := self.hparams.get("head_dim")) is None:
+            head_dim = self.hparams["hidden_size"] // num_heads
 
         if "ernie." in name:
             name = name.replace("ernie.", "model.")
@@ -2894,6 +2895,92 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Ernie4_5_MoeForCausalLM")
+class Ernie4_5MoeModel(Ernie4_5Model):
+    model_arch = gguf.MODEL_ARCH.ERNIE4_5_MOE
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._experts = [{} for _ in range(self.block_count)]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"])
+        self.gguf_writer.add_expert_used_count(self.hparams["moe_k"])
+        self.gguf_writer.add_interleave_moe_layer_step(self.hparams["moe_layer_interval"])
+        self.gguf_writer.add_leading_dense_block_count(self.hparams["moe_layer_start_index"])
+        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        if (shared_expert_intermediate_size := self.hparams.get('intermediate_size')) is not None and (num_key_value_heads := self.hparams.get('num_key_value_heads')) is not None:
+            self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size // num_key_value_heads)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Modify correction bias name as in DeepseekV2
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        # skip Multi-Token Prediction (MTP) layers (again, same as DeepseekV2)
+        match = re.match(r"model.mtp_block.(\d+)", name)
+        if match:
+            return []
+
+        # skip all other MTP tensors for now
+        match = re.match(r"model.mtp_emb_norm.(\d+)", name)
+        if match:
+            return []
+
+        match = re.match(r"model.mtp_hidden_norm.(\d+)", name)
+        if match:
+            return []
+
+        match = re.match(r"model.mtp_linear_proj.(\d+)", name)
+        if match:
+            return []
+
+        # process the experts separately
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["moe_num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["gate_proj", "up_proj", "down_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename_to_retrieve])
+                        del self._experts[bid][ename_to_retrieve]
+
+                    data_torch = torch.stack(datas, dim=0)
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+
+                return tensors
+            else:
+                return []
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register(
     "Qwen2VLModel",
     "Qwen2VLForConditionalGeneration",

gguf-py/gguf/constants.py

Lines changed: 24 additions & 0 deletions
@@ -364,6 +364,7 @@ class MODEL_ARCH(IntEnum):
     DOTS1 = auto()
     ARCEE = auto()
     ERNIE4_5 = auto()
+    ERNIE4_5_MOE = auto()
     HUNYUAN_MOE = auto()
     SMOLLM3 = auto()
     LFM2 = auto()
@@ -680,6 +681,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DOTS1: "dots1",
     MODEL_ARCH.ARCEE: "arcee",
     MODEL_ARCH.ERNIE4_5: "ernie4_5",
+    MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
     MODEL_ARCH.FALCON_H1: "falcon-h1",
     MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
     MODEL_ARCH.SMOLLM3: "smollm3",
@@ -2022,6 +2024,28 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP_SHEXP,
         MODEL_TENSOR.FFN_EXP_PROBS_B,
     ],
+    MODEL_ARCH.ERNIE4_5_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.PLM: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT,
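
The additions above follow the registration pattern used throughout gguf-py/gguf/constants.py: a new MODEL_ARCH member, a readable architecture name, and the list of MODEL_TENSOR kinds that architecture may carry. A self-contained miniature of that pattern, using toy enums rather than the real gguf-py tables, to show the consistency the tables are expected to keep:

    # Self-contained miniature of the arch/tensor registration pattern; toy enums and
    # dicts for illustration, not the real gguf-py definitions.
    from enum import IntEnum, auto

    class ToyArch(IntEnum):
        ERNIE4_5     = auto()
        ERNIE4_5_MOE = auto()

    class ToyTensor(IntEnum):
        TOKEN_EMBD   = auto()
        FFN_GATE_EXP = auto()
        FFN_UP_EXP   = auto()
        FFN_DOWN_EXP = auto()

    ARCH_NAMES = {
        ToyArch.ERNIE4_5:     "ernie4_5",
        ToyArch.ERNIE4_5_MOE: "ernie4_5-moe",
    }

    ARCH_TENSORS = {
        ToyArch.ERNIE4_5:     [ToyTensor.TOKEN_EMBD],
        ToyArch.ERNIE4_5_MOE: [ToyTensor.TOKEN_EMBD, ToyTensor.FFN_GATE_EXP,
                               ToyTensor.FFN_UP_EXP, ToyTensor.FFN_DOWN_EXP],
    }

    # Every architecture needs a unique name and a tensor list; this is the invariant
    # the new ERNIE4_5_MOE entries above have to satisfy.
    assert len(set(ARCH_NAMES.values())) == len(ARCH_NAMES)
    assert set(ARCH_TENSORS) == set(ARCH_NAMES)
    print(ARCH_NAMES[ToyArch.ERNIE4_5_MOE], [t.name for t in ARCH_TENSORS[ToyArch.ERNIE4_5_MOE]])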

gguf-py/gguf/tensor_mapping.py

Lines changed: 23 additions & 22 deletions
@@ -324,7 +324,8 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_EXP_PROBS_B: (
-            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
+            "model.layers.{bid}.mlp.gate.e_score_correction",        # deepseek-v3 dots1
+            "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
         ),
 
         # Feed-forward up
@@ -364,13 +365,13 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
-            "layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
-            "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
-            "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
-            "model.layers.{bid}.feed_forward.experts.up_proj", # llama4
-            "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
+            "layers.{bid}.feed_forward.experts.w3",            # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_v",    # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.v1",     # dbrx
+            "model.layers.{bid}.mlp.experts.up_proj",          # qwen2moe olmoe (merged) ernie4.5-moe
+            "model.layers.{bid}.block_sparse_moe.experts.w3",  # phimoe (merged)
+            "model.layers.{bid}.feed_forward.experts.up_proj", # llama4
+            "encoder.layers.{bid}.mlp.experts.mlp.w1",         # nomic-bert-moe
         ),
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -403,12 +404,12 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
-            "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
-            "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
-            "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
-            "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
+            "layers.{bid}.feed_forward.experts.w1",              # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear",        # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w1",       # dbrx
+            "model.layers.{bid}.mlp.experts.gate_proj",          # qwen2moe olmoe (merged) ernie4.5-moe
+            "model.layers.{bid}.block_sparse_moe.experts.w1",    # phimoe (merged)
+            "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
         ),
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -450,14 +451,14 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
-            "layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
-            "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
-            "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
-            "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
-            "model.layers.{bid}.feed_forward.experts.down_proj", # llama4
-            "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
+            "layers.{bid}.feed_forward.experts.w2",              # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_1",      # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w2",       # dbrx
+            "model.layers.{bid}.mlp.experts.down_proj",          # qwen2moe olmoe (merged) ernie4.5-moe
+            "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
+            "model.layers.{bid}.block_sparse_moe.experts.w2",    # phimoe (merged)
+            "model.layers.{bid}.feed_forward.experts.down_proj", # llama4
+            "encoder.layers.{bid}.mlp.experts.mlp.w2",           # nomic-bert-moe
         ),
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
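
The {bid} templates above are what TensorNameMap uses to translate checkpoint tensor names into GGUF tensor names. A small illustrative sketch of that substitution, with a hypothetical resolve helper rather than the real TensorNameMap API; the source templates come from the mappings above and the targets from the blk.%d names registered in src/llama-arch.cpp below:

    # Hypothetical helper for illustration, not the real gguf-py TensorNameMap.
    # Source templates and GGUF targets are taken from this commit's mappings.
    TEMPLATES = {
        "model.layers.{bid}.mlp.experts.gate_proj":              "blk.{bid}.ffn_gate_exps",
        "model.layers.{bid}.mlp.experts.up_proj":                "blk.{bid}.ffn_up_exps",
        "model.layers.{bid}.mlp.experts.down_proj":              "blk.{bid}.ffn_down_exps",
        "model.layers.{bid}.mlp.moe_statics.e_score_correction": "blk.{bid}.exp_probs_b",
    }

    def resolve(name: str, n_blocks: int) -> str | None:
        # Brute-force match; a real implementation would build a reverse lookup table instead.
        for bid in range(n_blocks):
            for src, dst in TEMPLATES.items():
                if src.format(bid=bid) == name:
                    return dst.format(bid=bid)
        return None

    print(resolve("model.layers.3.mlp.experts.gate_proj", n_blocks=8))   # blk.3.ffn_gate_exps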

src/llama-arch.cpp

Lines changed: 26 additions & 0 deletions
@@ -82,6 +82,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DOTS1, "dots1" },
     { LLM_ARCH_ARCEE, "arcee" },
     { LLM_ARCH_ERNIE4_5, "ernie4_5" },
+    { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
     { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
     { LLM_ARCH_SMOLLM3, "smollm3" },
     { LLM_ARCH_LFM2, "lfm2" },
@@ -1825,6 +1826,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_ERNIE4_5_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+        },
+    },
     {
         LLM_ARCH_HUNYUAN_MOE,
         {
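
At load time the %d placeholders above are filled with the block index. As a quick reference, a short sketch listing the per-block MoE tensor names an ernie4_5-moe GGUF is expected to expose, with a hypothetical toy layer count (in the actual model only the MoE blocks carry these, as controlled by moe_layer_start_index and moe_layer_interval in the converter above):

    # Per-block MoE tensor names for ernie4_5-moe, taken from the LLM_TENSOR_NAMES entries above.
    # n_layers is a hypothetical toy value, not a real Ernie 4.5 configuration.
    MOE_SUFFIXES = [
        "ffn_gate_inp", "ffn_gate_exps", "ffn_up_exps", "ffn_down_exps",
        "ffn_gate_shexp", "ffn_up_shexp", "ffn_down_shexp", "exp_probs_b",
    ]

    n_layers = 3
    for bid in range(n_layers):
        print([f"blk.{bid}.{suffix}" for suffix in MOE_SUFFIXES])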

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ enum llm_arch {
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
+    LLM_ARCH_ERNIE4_5_MOE,
     LLM_ARCH_HUNYUAN_MOE,
     LLM_ARCH_SMOLLM3,
     LLM_ARCH_LFM2,
