@@ -815,6 +815,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
815
815
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35" :
816
816
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
817
817
res = "minerva-7b"
818
+ if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664" :
819
+ # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
820
+ res = "hunyuan"
818
821
819
822
if res is None :
820
823
logger .warning ("\n " )
@@ -6788,6 +6791,160 @@ def set_gguf_parameters(self):
6788
6791
super ().set_gguf_parameters ()
6789
6792
self .gguf_writer .add_audio_stack_factor (self .global_config ["stack_factor" ])
6790
6793
6794
+
6795
+ @ModelBase .register ("HunYuanMoEV1ForCausalLM" )
6796
+ class HunYuanMoEModel (TextModel ):
6797
+ model_arch = gguf .MODEL_ARCH .HUNYUAN_MOE
6798
+
6799
+ def __init__ (self , * args , ** kwargs ):
6800
+ super ().__init__ (* args , ** kwargs )
6801
+ # For handling tied embeddings
6802
+ self ._tok_embd = None
6803
+
6804
+ def set_vocab (self ):
6805
+ from transformers import AutoTokenizer
6806
+ tokenizer = AutoTokenizer .from_pretrained (self .dir_model , trust_remote_code = True )
6807
+
6808
+ # 1. Get the pre-tokenizer identifier hash
6809
+ tokpre = self .get_vocab_base_pre (tokenizer )
6810
+
6811
+ # 2. Reverse-engineer the merges list from mergeable_ranks
6812
+ merges = []
6813
+ vocab = {}
6814
+ mergeable_ranks = tokenizer .mergeable_ranks
6815
+ for token , rank in mergeable_ranks .items ():
6816
+ vocab [QwenModel .token_bytes_to_string (token )] = rank
6817
+ if len (token ) == 1 :
6818
+ continue
6819
+ merged = QwenModel .bpe (mergeable_ranks , token , max_rank = rank )
6820
+ if len (merged ) == 2 : # todo this is an assert in Qwen, why?
6821
+ merges .append (' ' .join (map (QwenModel .token_bytes_to_string , merged )))
6822
+
6823
+ # 3. Generate the tokens and toktypes lists
6824
+ vocab_size = self .hparams ["vocab_size" ]
6825
+ assert tokenizer .vocab_size == vocab_size
6826
+ special_tokens = tokenizer .special_tokens
6827
+ reverse_vocab = {id_ : encoded_tok for encoded_tok , id_ in {** vocab , ** special_tokens }.items ()}
6828
+ tokens : list [str ] = []
6829
+ toktypes : list [int ] = []
6830
+ for i in range (vocab_size ):
6831
+ if i not in reverse_vocab :
6832
+ tokens .append (f"[PAD{ i } ]" )
6833
+ toktypes .append (gguf .TokenType .UNUSED )
6834
+ else :
6835
+ token = reverse_vocab [i ]
6836
+ tokens .append (token )
6837
+ if i in special_tokens .values ():
6838
+ toktypes .append (gguf .TokenType .CONTROL )
6839
+ else :
6840
+ toktypes .append (gguf .TokenType .NORMAL )
6841
+
6842
+ # 4. Write all vocab-related fields to the GGUF writer
6843
+ self .gguf_writer .add_tokenizer_model ("gpt2" )
6844
+ self .gguf_writer .add_tokenizer_pre (tokpre )
6845
+ self .gguf_writer .add_token_list (tokens )
6846
+ self .gguf_writer .add_token_types (toktypes )
6847
+ self .gguf_writer .add_token_merges (merges )
6848
+
6849
+ # 5. Add special tokens and chat templates
6850
+ special_vocab = gguf .SpecialVocab (self .dir_model , load_merges = False )
6851
+ special_vocab .add_to_gguf (self .gguf_writer )
6852
+ # FIX for BOS token: Overwrite incorrect id read from config.json
6853
+ self .gguf_writer .add_bos_token_id (127959 ) # <|bos|>
6854
+
6855
+ def set_gguf_parameters (self ):
6856
+ super ().set_gguf_parameters ()
6857
+ hparams = self .hparams
6858
+
6859
+ self .gguf_writer .add_expert_count (hparams ["num_experts" ])
6860
+ self .gguf_writer .add_expert_shared_feed_forward_length (hparams ["intermediate_size" ])
6861
+
6862
+ moe_intermediate_size = hparams ["moe_intermediate_size" ]
6863
+ assert all (n == moe_intermediate_size [0 ] for n in moe_intermediate_size )
6864
+ self .gguf_writer .add_expert_feed_forward_length (moe_intermediate_size [0 ])
6865
+
6866
+ moe_topk = hparams ["moe_topk" ]
6867
+ assert all (topk == moe_topk [0 ] for topk in moe_topk )
6868
+ self .gguf_writer .add_expert_used_count (moe_topk [0 ])
6869
+
6870
+ moe_shared_expert = hparams ["num_shared_expert" ]
6871
+ assert all (n == moe_shared_expert [0 ] for n in moe_shared_expert )
6872
+ self .gguf_writer .add_expert_shared_count (moe_shared_expert [0 ])
6873
+
6874
+ # Rope
6875
+ rope_scaling = hparams .get ("rope_scaling" , {})
6876
+ if rope_scaling .get ("type" ) == "dynamic" :
6877
+ # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
6878
+ # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
6879
+ alpha = rope_scaling .get ("alpha" , 1000 )
6880
+ base = hparams .get ("rope_theta" , 10000.0 )
6881
+ dim = (hparams ["hidden_size" ] // hparams ["num_attention_heads" ]) # 128
6882
+ scaled_base = base * (alpha ** (dim / (dim - 2 ))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251
6883
+ self .gguf_writer .add_rope_freq_base (scaled_base )
6884
+ self .gguf_writer .add_rope_scaling_type (gguf .RopeScalingType .NONE )
6885
+ self .gguf_writer .add_rope_scaling_factor (1 )
6886
+ # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
6887
+ self .gguf_writer .add_rope_scaling_orig_ctx_len (256 * 1024 ) # 256k context length
6888
+ self .gguf_writer .add_context_length (256 * 1024 ) # 256k context length
6889
+
6890
+ # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
6891
+ assert alpha == 1000 and base == 10000.0 and dim == 128 and self .hparams ["max_position_embeddings" ] in [32 * 1024 , 256 * 1024 ] , \
6892
+ "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
6893
+
6894
+ _experts : list [dict [str , Tensor ]] | None = None
6895
+
6896
+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
6897
+ if name == "model.embed_tokens.weight" :
6898
+ self ._tok_embd = data_torch .clone ()
6899
+
6900
+ if name == "lm_head.weight" :
6901
+ if self .hparams .get ("tie_word_embeddings" , False ):
6902
+ logger .info ("Skipping tied output layer 'lm_head.weight'" )
6903
+ return []
6904
+
6905
+ if name .find ("mlp.experts" ) != - 1 :
6906
+ n_experts = self .hparams ["num_experts" ]
6907
+ assert bid is not None
6908
+
6909
+ if self ._experts is None :
6910
+ self ._experts = [{} for _ in range (self .block_count )]
6911
+
6912
+ self ._experts [bid ][name ] = data_torch
6913
+
6914
+ if len (self ._experts [bid ]) >= n_experts * 3 :
6915
+ # merge the experts into a single 3d tensor
6916
+ tensors : list [tuple [str , Tensor ]] = []
6917
+ for w_name in ["down_proj" , "gate_proj" , "up_proj" ]:
6918
+ datas : list [Tensor ] = []
6919
+
6920
+ for xid in range (n_experts ):
6921
+ ename = f"model.layers.{ bid } .mlp.experts.{ xid } .{ w_name } .weight"
6922
+ datas .append (self ._experts [bid ][ename ])
6923
+ del self ._experts [bid ][ename ]
6924
+
6925
+ data_torch = torch .stack (datas , dim = 0 )
6926
+ merged_name = f"model.layers.{ bid } .mlp.experts.{ w_name } .weight"
6927
+ new_name = self .map_tensor_name (merged_name )
6928
+ tensors .append ((new_name , data_torch ))
6929
+
6930
+ return tensors
6931
+ else :
6932
+ return []
6933
+
6934
+ return [(self .map_tensor_name (name ), data_torch )]
6935
+
6936
+ def prepare_tensors (self ):
6937
+ super ().prepare_tensors ()
6938
+ if self ._experts is not None :
6939
+ experts = [k for d in self ._experts for k in d .keys ()]
6940
+ if len (experts ) > 0 :
6941
+ raise ValueError (f"Unprocessed experts: { experts } " )
6942
+
6943
+
6944
+ @ModelBase .register ("SmolLM3ForCausalLM" )
6945
+ class SmolLM3Model (LlamaModel ):
6946
+ model_arch = gguf .MODEL_ARCH .SMOLLM3
6947
+
6791
6948
###### CONVERSION LOGIC ######
6792
6949
6793
6950
0 commit comments