@@ -818,6 +818,18 @@ def get_vocab_base_pre(self, tokenizer) -> str:
818
818
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664" :
819
819
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
820
820
res = "hunyuan"
821
+ if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6" :
822
+ # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
823
+ res = "falcon-h1"
824
+ if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86" :
825
+ # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
826
+ res = "falcon-h1"
827
+ if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896" :
828
+ # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
829
+ res = "falcon-h1"
830
+ if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b" :
831
+ # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
832
+ res = "falcon-h1"
821
833
822
834
if res is None :
823
835
logger .warning ("\n " )
@@ -4899,17 +4911,19 @@ def set_vocab(self):
4899
4911
def set_gguf_parameters (self ):
4900
4912
d_model = self .find_hparam (["hidden_size" , "d_model" , "dim" ])
4901
4913
d_conv = self .find_hparam (["conv_kernel" , "d_conv" ], optional = True ) or 4
4902
- d_inner = self .find_hparam (["intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
4914
+ d_inner = self .find_hparam (["mamba_d_ssm" , " intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
4903
4915
d_state = self .find_hparam (["state_size" , "d_state" ], optional = True ) or 128
4904
- head_dim = self .find_hparam (["head_dim" ], optional = True ) or 64
4916
+ head_dim = self .find_hparam (["mamba_d_head" , " head_dim" ], optional = True ) or 64
4905
4917
n_group = self .find_hparam (["n_groups" ], optional = True ) or 1
4906
4918
4907
4919
rms_norm_eps = self .find_hparam (["layer_norm_epsilon" , "rms_norm_eps" ], optional = True ) or 1e-5
4908
4920
4909
4921
# Fail early for models which don't have a block expansion factor of 2
4910
4922
# TODO: does this really matter?
4911
- assert d_inner == 2 * d_model
4912
- assert d_inner % head_dim == 0
4923
+ # skip the assertion for FalconH1 Model
4924
+ if self .model_arch != gguf .MODEL_ARCH .FALCON_H1 :
4925
+ assert d_inner == 2 * d_model
4926
+ assert d_inner % head_dim == 0
4913
4927
4914
4928
self .gguf_writer .add_context_length (2 ** 20 ) # arbitrary value; for those who use the default
4915
4929
self .gguf_writer .add_embedding_length (d_model )
@@ -4946,7 +4960,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
4946
4960
data_torch = data_torch .reshape ((* data_torch .shape , 1 ))
4947
4961
elif self .match_model_tensor_name (new_name , gguf .MODEL_TENSOR .SSM_NORM , bid ):
4948
4962
d_model = self .find_hparam (["hidden_size" , "d_model" , "dim" ])
4949
- d_inner = self .find_hparam (["intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
4963
+ d_inner = self .find_hparam (["mamba_d_ssm" , " intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
4950
4964
n_group = self .hparams .get ("n_groups" , 1 )
4951
4965
data_torch = data_torch .reshape ((n_group , d_inner // n_group ))
4952
4966
@@ -6539,6 +6553,113 @@ def set_gguf_parameters(self):
6539
6553
self .gguf_writer .add_audio_stack_factor (self .global_config ["stack_factor" ])
6540
6554
6541
6555
6556
+ @ModelBase .register ("FalconH1ForCausalLM" )
6557
+ class FalconH1Model (Mamba2Model ):
6558
+ model_arch = gguf .MODEL_ARCH .FALCON_H1
6559
+
6560
+ def __init__ (self , * args , ** kwargs ):
6561
+ # Set the hparam prefixes for Falcon Mamba2
6562
+ self .hparam_prefixes = ["mamba" ]
6563
+
6564
+ # Initialize the base Mamba2Model
6565
+ super ().__init__ (* args , ** kwargs )
6566
+
6567
+ # Use Llama conversion for attention
6568
+ self ._transformer_model_class = LlamaModel
6569
+
6570
+ # n_group and d_inner are used during reshape_tensors for mamaba2
6571
+ self .n_group = self .find_hparam (["n_groups" ])
6572
+ self .d_inner = self .find_hparam (["mamba_d_ssm" ])
6573
+ self .d_head = self .find_hparam (["d_head" ])
6574
+
6575
+ # Initialize any Falcon Mamba2 specific attributes
6576
+ self .has_attention = True # Falcon Mamba2 has attention components
6577
+
6578
+ # Load Falcon-H1 multipliers from hyperparameters
6579
+ self .attention_in_multiplier = self .find_hparam (["attention_in_multiplier" ], optional = True )
6580
+ self .attention_out_multiplier = self .find_hparam (["attention_out_multiplier" ], optional = True )
6581
+ self .ssm_in_multiplier = self .find_hparam (["ssm_in_multiplier" ], optional = True )
6582
+ self .ssm_out_multiplier = self .find_hparam (["ssm_out_multiplier" ], optional = True )
6583
+ self .mlp_multipliers = self .find_hparam (["mlp_multipliers" ], optional = True )
6584
+ self .ssm_multipliers = self .find_hparam (["ssm_multipliers" ], optional = True )
6585
+ self .intermediate_size = self .find_hparam (["intermediate_size" ])
6586
+ self .key_multiplier = self .find_hparam (["key_multiplier" ], optional = True )
6587
+
6588
+ def find_hparam (self , keys : Iterable [str ], * args , ** kwargs ) -> Any :
6589
+ prefixed = []
6590
+ for pfx in self .hparam_prefixes :
6591
+ prefixed .extend (
6592
+ "_" .join ([pfx , k ])
6593
+ for k in keys
6594
+ )
6595
+ keys = list (keys ) + prefixed
6596
+ return super ().find_hparam (keys , * args , ** kwargs )
6597
+
6598
+ def set_vocab (self ):
6599
+ self ._set_vocab_gpt2 ()
6600
+
6601
+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
6602
+ tensors = list (super ().modify_tensors (data_torch , name , bid ))
6603
+ tensor = tensors [0 ][1 ]
6604
+
6605
+ if "down_proj" in name :
6606
+ tensor = tensor * self .mlp_multipliers [1 ]
6607
+ elif "gate_proj" in name :
6608
+ tensor = tensor * self .mlp_multipliers [0 ]
6609
+ elif "k_proj" in name :
6610
+ tensor = tensor * self .key_multiplier * self .attention_in_multiplier
6611
+ elif "q_proj" in name :
6612
+ tensor = tensor * self .attention_in_multiplier
6613
+ elif "v_proj" in name :
6614
+ tensor = tensor * self .attention_in_multiplier
6615
+ elif "o_proj" in name :
6616
+ tensor = tensor * self .attention_out_multiplier
6617
+ elif "out_proj" in name :
6618
+ tensor = tensor * self .ssm_out_multiplier
6619
+ elif "in_proj" in name :
6620
+ tensor = tensor * self .ssm_in_multiplier
6621
+ zxbcdt_multipliers = self .hparams ["ssm_multipliers" ]
6622
+ intermediate_size = self .hparams ["mamba_d_ssm" ]
6623
+ groups_time_state_size = self .hparams ["mamba_n_groups" ] * self .hparams ["mamba_d_state" ]
6624
+ tensor [:intermediate_size , :] *= zxbcdt_multipliers [0 ]
6625
+ tensor [intermediate_size :2 * intermediate_size , :] *= zxbcdt_multipliers [1 ]
6626
+ tensor [2 * intermediate_size :2 * intermediate_size + groups_time_state_size , :] *= zxbcdt_multipliers [2 ]
6627
+ tensor [2 * intermediate_size + groups_time_state_size :2 * intermediate_size + 2 * groups_time_state_size , :] *= zxbcdt_multipliers [3 ]
6628
+ tensor [2 * intermediate_size + 2 * groups_time_state_size :, :] *= zxbcdt_multipliers [4 ]
6629
+ elif "lm_head" in name :
6630
+ tensor = tensor * self .hparams ["lm_head_multiplier" ]
6631
+ elif "embed_tokens" in name :
6632
+ tensor = tensor * self .hparams ["embedding_multiplier" ]
6633
+ elif "mamba.norm" in name :
6634
+ tensor = tensor .reshape (self .n_group , self .d_inner // self .n_group )
6635
+
6636
+ tensors = [(tensors [0 ][0 ], tensor )]
6637
+ return tensors
6638
+
6639
+ def set_gguf_parameters (self ):
6640
+ super ().set_gguf_parameters ()
6641
+
6642
+ ## General Params ##
6643
+ self .gguf_writer .add_vocab_size (self .hparams ["vocab_size" ])
6644
+ # Override some Mamba2 defaults
6645
+ self .gguf_writer .add_block_count (self .block_count )
6646
+ self .gguf_writer .add_context_length (self .hparams .get ("max_position_embeddings" , 0 ))
6647
+ self .gguf_writer .add_feed_forward_length (self .hparams ["intermediate_size" ])
6648
+
6649
+ ## Attention params ##
6650
+ self .gguf_writer .add_head_count (self .hparams ["num_attention_heads" ]) # Override value 0 from Mamba2
6651
+ self .gguf_writer .add_head_count_kv (self .hparams ["num_key_value_heads" ])
6652
+ self .gguf_writer .add_key_length (self .hparams ["head_dim" ])
6653
+ self .gguf_writer .add_value_length (self .hparams ["head_dim" ])
6654
+
6655
+ ## Validation ##
6656
+ assert self .hparams .get ("hidden_act" ) in [None , "silu" ], "Only SILU activation supported"
6657
+ assert self .d_inner % self .d_head == 0 , f"SSM inner size { self .d_inner } not a multiple of head dim { self .d_head } "
6658
+
6659
+ # Add any other Falcon Mamba2 specific configuration
6660
+ self .gguf_writer .add_rope_freq_base (self .find_hparam (["rope_theta" ]))
6661
+
6662
+
6542
6663
@ModelBase .register ("HunYuanMoEV1ForCausalLM" )
6543
6664
class HunYuanMoEModel (TextModel ):
6544
6665
model_arch = gguf .MODEL_ARCH .HUNYUAN_MOE
0 commit comments