
Commit 61a57b8

ngxson and kooshi authored and committed
model : add hunyuan moe (ggml-org#14425)
* model : add hunyuan moe
* tokenizer ok
* fix tensor name
* cgraph init
* chat template
* wip
* almost working
* skip embed, fix bos
* cleanup
* yarn scaling
* cleanup
* correct rope type
* failed token fix
* ntk alpha freq_base
* tokenization working
* cleanup and pr changes
* vocab_size sanity check
* ntk alpha generic
* Update convert_hf_to_gguf.py
* Apply suggestions from code review
* fix regression
* fix style

---------

Co-authored-by: kooshi <1934337+kooshi@users.noreply.github.com>
1 parent 8ac5f58 commit 61a57b8

12 files changed: +449 -0 lines changed

convert_hf_to_gguf.py

Lines changed: 152 additions & 0 deletions
@@ -815,6 +815,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
             # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
             res = "minerva-7b"
+        if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
+            # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
+            res = "hunyuan"

         if res is None:
             logger.warning("\n")
@@ -6535,6 +6538,155 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])

+
+@ModelBase.register("HunYuanMoEV1ForCausalLM")
+class HunYuanMoEModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # For handling tied embeddings
+        self._tok_embd = None
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+
+        # 1. Get the pre-tokenizer identifier hash
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        # 2. Reverse-engineer the merges list from mergeable_ranks
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[QwenModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+            if len(merged) == 2:  # todo this is an assert in Qwen, why?
+                merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
+
+        # 3. Generate the tokens and toktypes lists
+        vocab_size = self.hparams["vocab_size"]
+        assert tokenizer.vocab_size == vocab_size
+        special_tokens = tokenizer.special_tokens
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
+        tokens: list[str] = []
+        toktypes: list[int] = []
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.UNUSED)
+            else:
+                token = reverse_vocab[i]
+                tokens.append(token)
+                if i in special_tokens.values():
+                    toktypes.append(gguf.TokenType.CONTROL)
+                else:
+                    toktypes.append(gguf.TokenType.NORMAL)
+
+        # 4. Write all vocab-related fields to the GGUF writer
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_token_merges(merges)
+
+        # 5. Add special tokens and chat templates
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.add_to_gguf(self.gguf_writer)
+        # FIX for BOS token: Overwrite incorrect id read from config.json
+        self.gguf_writer.add_bos_token_id(127959)  # <|bos|>
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+
+        self.gguf_writer.add_expert_count(hparams["num_experts"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
+
+        moe_intermediate_size = hparams["moe_intermediate_size"]
+        assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
+        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])
+
+        moe_topk = hparams["moe_topk"]
+        assert all(topk == moe_topk[0] for topk in moe_topk)
+        self.gguf_writer.add_expert_used_count(moe_topk[0])
+
+        moe_shared_expert = hparams["num_shared_expert"]
+        assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
+        self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
+
+        # Rope
+        rope_scaling = hparams.get("rope_scaling", {})
+        if rope_scaling.get("type") == "dynamic":
+            # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+            # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
+            alpha = rope_scaling.get("alpha", 1000)
+            base = hparams.get("rope_theta", 10000.0)
+            dim = (hparams["hidden_size"] // hparams["num_attention_heads"])  # 128
+            scaled_base = base * (alpha ** (dim / (dim - 2)))  # 10000 * (1000 ** (128 / 126)) = 11158839.9251
+            self.gguf_writer.add_rope_freq_base(scaled_base)
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+            self.gguf_writer.add_rope_scaling_factor(1)
+            # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)  # 256k context length
+            self.gguf_writer.add_context_length(256 * 1024)  # 256k context length
+
+            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
+            assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
+                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name == "model.embed_tokens.weight":
+            self._tok_embd = data_torch.clone()
+
+        if name == "lm_head.weight":
+            if self.hparams.get("tie_word_embeddings", False):
+                logger.info("Skipping tied output layer 'lm_head.weight'")
+                return []
+
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                # merge the experts into a single 3d tensor
+                tensors: list[tuple[str, Tensor]] = []
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
 ###### CONVERSION LOGIC ######

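Note on the RoPE handling above: instead of emitting a scaling type, the converter folds the NTK-aware alpha scaling into a single enlarged frequency base. A quick sanity check of that arithmetic, using the values stated in the code comments (alpha = 1000, rope_theta = 10000, head dim 128); this snippet is illustrative only and is not part of the conversion script:

    # sketch: reproduce the scaled rope_freq_base computed in set_gguf_parameters above
    alpha = 1000        # rope_scaling["alpha"] per the comments in the diff
    base = 10000.0      # rope_theta
    dim = 128           # hidden_size // num_attention_heads, per the comment above
    scaled_base = base * (alpha ** (dim / (dim - 2)))
    print(scaled_base)  # ~11158839.93, matching the 11158839.9251 noted in the diff
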
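The expert-merge branch of modify_tensors collects each layer's per-expert down/gate/up projections and, once all n_experts * 3 tensors for the layer have arrived, stacks every projection into one 3-D tensor so it can be written as a single ffn_*_exps GGUF tensor. A minimal sketch of that stacking (the sizes below are made up for illustration, not the real model dimensions):

    # sketch: what the torch.stack call above does to the collected per-expert weights
    import torch

    n_experts, n_embd, n_ff = 4, 8, 16                                  # tiny illustrative sizes
    gate_projs = [torch.randn(n_ff, n_embd) for _ in range(n_experts)]  # one 2-D weight per expert

    merged = torch.stack(gate_projs, dim=0)  # -> shape [n_experts, n_ff, n_embd]
    print(merged.shape)                      # torch.Size([4, 16, 8])
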
convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
+    {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
 ]

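For context, the chkhsh values in this table are fingerprints of tokenizer behaviour: convert_hf_to_gguf_update.py downloads each listed repo, and the converter hashes the token ids the tokenizer produces for a fixed probe text. A rough sketch of the idea only (the probe text below is a placeholder, not the actual string the scripts use):

    # sketch: how a tokenizer fingerprint like the "chkhsh" above is derived (probe text is a placeholder)
    from hashlib import sha256
    from transformers import AutoTokenizer

    probe_text = "..."  # placeholder; the real scripts hash a long fixed test string
    tok = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct", trust_remote_code=True)
    ids = tok.encode(probe_text)
    print(sha256(str(ids).encode()).hexdigest())  # should match the registered chkhsh
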
gguf-py/gguf/constants.py

Lines changed: 23 additions & 0 deletions
@@ -357,6 +357,7 @@ class MODEL_ARCH(IntEnum):
     DOTS1 = auto()
     ARCEE = auto()
     ERNIE4_5 = auto()
+    HUNYUAN_MOE = auto()


 class VISION_PROJECTOR_TYPE(IntEnum):

@@ -660,6 +661,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DOTS1: "dots1",
     MODEL_ARCH.ARCEE: "arcee",
     MODEL_ARCH.ERNIE4_5: "ernie4_5",
+    MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
 }

 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {

@@ -2211,6 +2213,27 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.HUNYUAN_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     # TODO
 }

gguf-py/gguf/tensor_mapping.py

Lines changed: 6 additions & 0 deletions
@@ -303,6 +303,7 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
             "model.layers.{bid}.feed_forward.router", # llama4
             "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
+            "model.layers.{bid}.mlp.gate.wg", # hunyuan
         ),

         MODEL_TENSOR.FFN_GATE_INP_SHEXP: (

@@ -362,6 +363,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
             "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
             "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
+            "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
         ),

         # AWQ-activation gate

@@ -398,6 +400,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
             "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
             "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
+            "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
         ),

         # Feed-forward down

@@ -447,11 +450,13 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
             "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
             "model.layers.{bid}.shared_mlp.output_linear", # granitemoe
+            "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
         ),

         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
             "model.layers.{bid}.self_attn.q_layernorm", # persimmon
+            "model.layers.{bid}.self_attn.query_layernorm", # hunyuan
             "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
             "transformer.blocks.{bid}.attn.q_ln", # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2

@@ -461,6 +466,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_K_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
             "model.layers.{bid}.self_attn.k_layernorm", # persimmon
+            "model.layers.{bid}.self_attn.key_layernorm", # hunyuan
             "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
             "transformer.blocks.{bid}.attn.k_ln", # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -117,6 +117,7 @@ extern "C" {
     LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
     LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
     LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
+    LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
 };

 enum llama_rope_type {

src/llama-arch.cpp

Lines changed: 24 additions & 0 deletions
@@ -78,6 +78,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DOTS1, "dots1" },
     { LLM_ARCH_ARCEE, "arcee" },
     { LLM_ARCH_ERNIE4_5, "ernie4_5" },
+    { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };

@@ -1694,6 +1695,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_HUNYUAN_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ enum llm_arch {
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
+    LLM_ARCH_HUNYUAN_MOE,
     LLM_ARCH_UNKNOWN,
 };

src/llama-chat.cpp

Lines changed: 18 additions & 0 deletions
@@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "bailing", LLM_CHAT_TEMPLATE_BAILING },
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
+    { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
 };

 llm_chat_template llm_chat_template_from_str(const std::string & name) {

@@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_LLAMA4;
     } else if (tmpl_contains("<|endofuserprompt|>")) {
         return LLM_CHAT_TEMPLATE_DOTS1;
+    } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
+        return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }

@@ -665,6 +668,21 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|response|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
+        // tencent/Hunyuan-A13B-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|startoftext|>" << message->content << "<|extra_4|>";
+            } else if (role == "assistant") {
+                ss << "<|startoftext|>" << message->content << "<|eos|>";
+            } else {
+                ss << "<|startoftext|>" << message->content << "<|extra_0|>";
+            }
+        }
+        if (add_ass) {
+            ss << "<|startoftext|>";
+        }
     } else {
         // template not supported
         return -1;

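To make the new chat template concrete, here is a rough Python mock-up of what the LLM_CHAT_TEMPLATE_HUNYUAN_MOE branch above renders; the wrapper tokens come from the C++ code, while the example messages are invented for illustration:

    # sketch: mirrors the HUNYUAN_MOE branch of llm_chat_apply_template above
    def hunyuan_moe_prompt(messages, add_ass=True):
        out = ""
        for role, content in messages:
            if role == "system":
                out += f"<|startoftext|>{content}<|extra_4|>"
            elif role == "assistant":
                out += f"<|startoftext|>{content}<|eos|>"
            else:  # user and any other role
                out += f"<|startoftext|>{content}<|extra_0|>"
        if add_ass:
            out += "<|startoftext|>"  # cue the model to begin its reply
        return out

    print(hunyuan_moe_prompt([("system", "You are helpful."), ("user", "Hi!")]))
    # <|startoftext|>You are helpful.<|extra_4|><|startoftext|>Hi!<|extra_0|><|startoftext|>
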
src/llama-chat.h

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_LLAMA4,
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
+    LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
