@@ -295,6 +295,7 @@ def prepare_tensors(self):
                         gguf.MODEL_TENSOR.FFN_GATE_INP,
                         gguf.MODEL_TENSOR.POS_EMBD,
                         gguf.MODEL_TENSOR.TOKEN_TYPES,
+                        gguf.MODEL_TENSOR.SSM_CONV1D,
                     )
                 )
                 or not name.endswith(".weight")
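
As far as I can tell, this one-line addition lets the shared prepare_tensors path keep Mamba's conv1d tensors in F32, which is what makes the per-model tensor_force_quant override removed further down redundant. A rough standalone sketch of the idea, with a hypothetical choose_qtype helper standing in for the real condition:

import gguf

# Hypothetical helper, not the real prepare_tensors logic: tensors of these kinds are
# never quantized, regardless of the requested output file type.
ALWAYS_F32 = {
    gguf.MODEL_TENSOR.FFN_GATE_INP,
    gguf.MODEL_TENSOR.POS_EMBD,
    gguf.MODEL_TENSOR.TOKEN_TYPES,
    gguf.MODEL_TENSOR.SSM_CONV1D,  # newly covered by this change
}

def choose_qtype(tensor_kind: gguf.MODEL_TENSOR, requested: gguf.GGMLQuantizationType) -> gguf.GGMLQuantizationType:
    # Keep precision-sensitive tensors in full float32; otherwise honour the request.
    return gguf.GGMLQuantizationType.F32 if tensor_kind in ALWAYS_F32 else requested
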
@@ -590,6 +591,15 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249":
             # ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M
             res = "smollm"
+        if chkhsh == "3c30d3ad1d6b64202cd222813e7736c2db6e1bd6d67197090fc1211fbc612ae7":
+            # ref: https://huggingface.co/bigscience/bloom
+            res = "bloom"
+        if chkhsh == "bc01ce58980e1db43859146dc51b1758b3b88729b217a74792e9f8d43e479d21":
+            # ref: https://huggingface.co/TurkuNLP/gpt3-finnish-small
+            res = "gpt3-finnish"
+        if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
+            # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
+            res = "exaone"
 
         if res is None:
             logger.warning("\n")
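
For context (my reading, not something stated in the diff itself): each chkhsh compared above is a SHA-256 digest of the token IDs the model's tokenizer produces for a fixed probe text, so models whose pre-tokenizers behave identically collapse onto the same res value. A hedged sketch of how such a digest could be computed, with probe_text as a placeholder rather than the real string used by convert_hf_to_gguf.py:

import hashlib
from transformers import AutoTokenizer  # assumption: the tokenizer is loaded via transformers

def pretokenizer_hash(model_id: str, probe_text: str) -> str:
    # Encode a fixed probe string and hash the resulting token IDs; e.g. the digest for
    # "bigscience/bloom" would then be matched against the hard-coded values above.
    tok = AutoTokenizer.from_pretrained(model_id)
    return hashlib.sha256(str(tok.encode(probe_text)).encode()).hexdigest()
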
@@ -893,7 +903,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return tensors
 
 
-@Model.register("BloomForCausalLM")
+@Model.register("BloomForCausalLM", "BloomModel")
 class BloomModel(Model):
     model_arch = gguf.MODEL_ARCH.BLOOM
 
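
Registering the extra "BloomModel" name means checkpoints whose config.json lists that string under "architectures" resolve to the same converter. A minimal sketch of the registration pattern (not the real Model.register implementation):

from typing import Type

_registry: dict[str, Type] = {}

def register(*names: str):
    # Map every given architecture name to the decorated converter class.
    def inner(cls: Type) -> Type:
        for name in names:
            _registry[name] = cls
        return cls
    return inner

@register("BloomForCausalLM", "BloomModel")
class BloomConverterSketch:
    pass

assert _registry["BloomModel"] is BloomConverterSketch
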
@@ -2702,7 +2712,7 @@ class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
 
 
-@Model.register("MambaForCausalLM", "MambaLMHeadModel")
+@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
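
The same registration mechanism is what lets FalconMamba checkpoints reuse MambaModel. A hedged illustration of the config.json fields the two FalconMamba-related changes key on (the values are the ones the code tests for, not a verbatim copy of any particular config):

falcon_mamba_config_sketch = {
    "architectures": ["FalconMambaForCausalLM"],  # picked up by @Model.register(...) above
    "model_type": "falcon_mamba",                 # picked up by the find_hparam check in the next hunk
}
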
@@ -2733,20 +2743,24 @@ def set_gguf_parameters(self):
         # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
         dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
-
+        use_dt_b_c_norm = False
+        # For falconmamba we do apply RMS norm on B / DT and C layers
+        if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
+            use_dt_b_c_norm = True
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
 
         self.gguf_writer.add_context_length(2 ** 20)  # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
-        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_ssm_conv_kernel(d_conv)
         self.gguf_writer.add_ssm_inner_size(d_inner)
         self.gguf_writer.add_ssm_state_size(d_state)
         self.gguf_writer.add_ssm_time_step_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm)  # For classic Mamba we don't apply rms norm on B / DT layers
         self.gguf_writer.add_file_type(self.ftype)
 
     _tok_embd = None
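
The new add_ssm_dt_b_c_rms flag only records metadata; the actual normalization happens at inference time in llama.cpp. As an illustration only (a sketch of the operation the flag refers to, not the llama.cpp implementation), RMS-normalizing the dt/B/C streams looks like:

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Weightless RMS normalization over the last dimension.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

# Sketch: in a FalconMamba-style block the dt, B and C projections would each be
# normalized before the selective scan, whereas classic Mamba uses them as-is:
# dt, B, C = rms_norm(dt), rms_norm(B), rms_norm(C)
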
@@ -2773,23 +2787,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(new_name, data_torch)]
 
-    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
-        if bid is not None and new_name in (
-            self.format_tensor_name(
-                n, bid, ".weight" if name.endswith(".weight") else ""
-            )
-            for n in [
-                gguf.MODEL_TENSOR.SSM_CONV1D,
-                gguf.MODEL_TENSOR.SSM_X,
-                gguf.MODEL_TENSOR.SSM_DT,
-                gguf.MODEL_TENSOR.SSM_A,
-                gguf.MODEL_TENSOR.SSM_D,
-            ]
-        ):
-            return gguf.GGMLQuantizationType.F32
-
-        return super().tensor_force_quant(name, new_name, bid, n_dims)
-
 
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
@@ -3734,8 +3731,120 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         name = name.removeprefix("transformer.")
         return [(self.map_tensor_name(name), data_torch)]
 
-###### CONVERSION LOGIC ######
 
+@Model.register("NemotronForCausalLM")
+class NemotronModel(Model):
+    model_arch = gguf.MODEL_ARCH.NEMOTRON
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+        self.gguf_writer.add_pad_token_id(0)
+        self.gguf_writer.add_unk_token_id(1)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        f_norm_eps = self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon", "norm_eps"])
+        self.gguf_writer.add_layer_norm_eps(f_norm_eps)
+
+        # * Partial RoPE
+        rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"])
+        n_embd = self.find_hparam(["hidden_size", "n_embd"])
+        n_head = self.find_hparam(["num_attention_heads", "n_head"])
+        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
+
+        # * RopeScaling for Nemotron
+        if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+        else:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side
+        # model.layers.{l}.input_layernorm.weight
+        # model.layers.{l}.post_attention_layernorm.weight
+        # model.norm.weight
+        if name.endswith("norm.weight"):
+            data_torch = data_torch + 1
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
+@Model.register("ExaoneForCausalLM")
+class ExaoneModel(Model):
+    model_arch = gguf.MODEL_ARCH.EXAONE
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+
+        assert (hparams["activation_function"] == "silu")
+
+        max_position_embeddings = hparams["max_position_embeddings"]
+        embed_dim = hparams["hidden_size"]
+        num_heads = hparams["num_attention_heads"]
+        num_kv_heads = hparams.get("num_key_value_heads", num_heads)
+        layer_norm_eps = hparams["layer_norm_epsilon"]
+        intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
+        num_layers = hparams["num_layers"]
+        # ignore for now as EXAONE-3.0-7.8B-Instruct attention_dropout is 0.0
+        # attention_dropout_rate = hparams["attention_dropout"]
+        # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
+        # embed_dropout_rate = hparams["embed_dropout"]
+        self.gguf_writer.add_embedding_length(embed_dim)
+        self.gguf_writer.add_head_count(num_heads)
+        self.gguf_writer.add_head_count_kv(num_kv_heads)
+        self.gguf_writer.add_context_length(max_position_embeddings)
+        self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_block_count(num_layers)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        if (rope_theta := self.hparams.get("rope_theta")) is not None:
+            self.gguf_writer.add_rope_freq_base(rope_theta)
+        rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
+        rotary_factor = rotary_factor if rotary_factor is not None else 1.0
+        self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
+        if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]:
+            if hparams["rope_scaling"].get("type") == "linear":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+
+    def prepare_tensors(self):
+        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+            if rope_scaling.get("rope_type", '').lower() == "llama3":
+                base = self.hparams.get("rope_theta", 10000.0)
+                dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+                factor = rope_scaling.get("factor", 8.0)
+                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+                old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
+
+                low_freq_wavelen = old_context_len / low_freq_factor
+                high_freq_wavelen = old_context_len / high_freq_factor
+                assert low_freq_wavelen != high_freq_wavelen
+
+                rope_factors = []
+                for freq in freqs:
+                    wavelen = 2 * math.pi / freq
+                    if wavelen < high_freq_wavelen:
+                        rope_factors.append(1)
+                    elif wavelen > low_freq_wavelen:
+                        rope_factors.append(factor)
+                    else:
+                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+                self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
+
+        super().prepare_tensors()
+
+
+###### CONVERSION LOGIC ######
 
 
 # tree of lazy tensors
 class LazyTorchTensor(gguf.LazyBase):
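
The llama3-style rope_factors loop in ExaoneModel.prepare_tensors above can be read as a standalone function. A sketch with the same math (the defaults mirror the fallbacks used above; the head_dim in the usage comment is illustrative):

import math
import numpy as np

def llama3_rope_factors(head_dim: int, base: float = 10000.0, factor: float = 8.0,
                        low_freq_factor: float = 1.0, high_freq_factor: float = 4.0,
                        old_context_len: int = 8192) -> np.ndarray:
    # Short wavelengths are left untouched, long wavelengths are stretched by `factor`,
    # and the band in between is smoothly interpolated.
    freqs = 1.0 / (base ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
    low_wl = old_context_len / low_freq_factor
    high_wl = old_context_len / high_freq_factor
    out = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_wl:
            out.append(1.0)
        elif wavelen > low_wl:
            out.append(factor)
        else:
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            out.append(1.0 / ((1.0 - smooth) / factor + smooth))
    return np.array(out, dtype=np.float32)

# e.g. a model with hidden size 4096 and 32 heads: llama3_rope_factors(head_dim=4096 // 32)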