@@ -4688,6 +4688,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
4688
4688
with open (dir_model / "config.json" , "r" , encoding = "utf-8" ) as f :
4689
4689
hparams = json .load (f )
4690
4690
super ().__init__ (dir_model , * args , hparams = hparams , ** kwargs )
4691
+ self .d_model = self .find_hparam (["hidden_size" , "d_model" , "dim" ])
4692
+ self .d_inner = self .find_hparam (["intermediate_size" , "d_inner" ], optional = True ) or 2 * self .d_model
4693
+ self .n_group = self .hparams .get ("n_groups" , 1 )
4691
4694
4692
4695
def set_vocab (self ):
4693
4696
vocab_size = self .hparams ["vocab_size" ]
@@ -4758,10 +4761,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
4758
4761
# (D is also unsqueezed, but for more straightforward broadcast internally)
4759
4762
data_torch = data_torch .reshape ((* data_torch .shape , 1 ))
4760
4763
elif self .match_model_tensor_name (new_name , gguf .MODEL_TENSOR .SSM_NORM , bid ):
4761
- d_model = self .find_hparam (["hidden_size" , "d_model" , "dim" ])
4762
- d_inner = self .find_hparam (["intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
4763
- n_group = self .hparams .get ("n_groups" , 1 )
4764
- data_torch = data_torch .reshape ((n_group , d_inner // n_group ))
4764
+ data_torch = data_torch .reshape ((self .n_group , self .d_inner // self .n_group ))
4765
4765
4766
4766
if name .endswith (".A_log" ):
4767
4767
logger .debug ("A_log --> A ==> " + new_name )
@@ -4770,6 +4770,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
4770
4770
yield (new_name , data_torch )
4771
4771
4772
4772
4773
+ @ModelBase .register ("BambaForCausalLM" )
4774
+ class BambaModel (Mamba2Model ):
4775
+ """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
4776
+ model_arch = gguf .MODEL_ARCH .BAMBA
4777
+ undo_permute = True
4778
+
4779
+ def __init__ (self , * args , ** kwargs ):
4780
+
4781
+ # Hybrid mamba models use a prefix for the mamba-specific params.
4782
+ # TODO: Extend this if the prefix(es) need to be configurable
4783
+ self .hparam_prefixes = ["mamba" ]
4784
+
4785
+ super ().__init__ (* args , ** kwargs )
4786
+
4787
+ # Use Llama conversion for attention
4788
+ self ._transformer_model_class : type [TextModel ] = LlamaModel
4789
+
4790
+ # Lists of which layers use ssm vs attention
4791
+ self ._attn_layers = self .hparams .get ("attn_layer_indices" , [])
4792
+ if not self ._attn_layers :
4793
+ attn_period = self .hparams .get ("attn_layer_period" )
4794
+ assert attn_period , "Didn't find attn_layer_indices or attn_layer_period"
4795
+ attn_offset = self .hparams .get ("attn_layer_offset" )
4796
+ assert attn_offset is not None , "No attention layer offset set with attn_layer_period"
4797
+ self ._attn_layers = [
4798
+ i for i in range (self .block_count )
4799
+ if i % attn_period == attn_offset
4800
+ ]
4801
+ self ._ssm_layers = [
4802
+ i for i in range (self .block_count )
4803
+ if i not in self ._attn_layers
4804
+ ]
4805
+
4806
+ # n_group and d_inner are used during reshape_tensors for mamaba2
4807
+ self .d_model = self .find_hparam (["hidden_size" , "d_model" ])
4808
+ self .n_group = self .find_hparam (["n_groups" ])
4809
+ self .d_inner = self .find_hparam (["expand" ]) * self .d_model
4810
+
4811
+ def find_hparam (self , keys : Iterable [str ], * args , ** kwargs ) -> Any :
4812
+ prefixed = []
4813
+ for pfx in self .hparam_prefixes :
4814
+ prefixed .extend (
4815
+ "_" .join ([pfx , k ])
4816
+ for k in keys
4817
+ )
4818
+ keys = list (keys ) + prefixed
4819
+ return super ().find_hparam (keys , * args , ** kwargs )
4820
+
4821
+ def set_gguf_parameters (self ):
4822
+
4823
+ ## General Params ##
4824
+ self .gguf_writer .add_embedding_length (self .d_model )
4825
+ self .gguf_writer .add_block_count (self .block_count )
4826
+ self .gguf_writer .add_context_length (self .hparams .get ("max_position_embeddings" , 0 ))
4827
+ self .gguf_writer .add_vocab_size (self .hparams ["vocab_size" ])
4828
+ self .gguf_writer .add_feed_forward_length (self .hparams ["intermediate_size" ])
4829
+
4830
+ ## Mamba mixer params ##
4831
+ self .gguf_writer .add_ssm_conv_kernel (self .find_hparam (["conv_kernel" , "d_conv" ]))
4832
+ self .gguf_writer .add_ssm_state_size (self .find_hparam (["state_size" , "d_state" ]))
4833
+ self .gguf_writer .add_ssm_group_count (self .n_group )
4834
+ self .gguf_writer .add_ssm_inner_size (self .d_inner )
4835
+ # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
4836
+ # in llama.cpp
4837
+ self .gguf_writer .add_ssm_time_step_rank (self .find_hparam (["n_heads" ]))
4838
+
4839
+ ## Attention params ##
4840
+ self .gguf_writer .add_attn_layer_indices (self ._attn_layers )
4841
+ self .gguf_writer .add_rope_dimension_count (self .hparams ["attn_rotary_emb" ])
4842
+ self .gguf_writer .add_head_count (self .hparams ["num_attention_heads" ])
4843
+ self .gguf_writer .add_head_count_kv (self .find_hparam (["num_key_value_heads" , "n_head_kv" ]))
4844
+
4845
+ ## Feed Forward Params ##
4846
+ self .gguf_writer .add_layer_norm_rms_eps (
4847
+ self .find_hparam (["layer_norm_epsilon" , "rms_norm_eps" ], optional = True ) or 1e-5
4848
+ )
4849
+
4850
+ ## Validation ##
4851
+ d_head = self .find_hparam (["d_head" ], optional = True ) or 64
4852
+ assert self .hparams .get ("hidden_act" ) in [None , "silu" ], "Only SILU activation supported"
4853
+ assert self .d_inner % d_head == 0 , f"SSM inner size { self .d_inner } not a multiple of head dim { d_head } "
4854
+
4855
+ def modify_tensors (
4856
+ self , data_torch : Tensor , name : str , bid : int | None
4857
+ ) -> Iterable [tuple [str , Tensor ]]:
4858
+
4859
+ # Determine whether this is a mamaba layer or an attention layer
4860
+ if bid in self ._ssm_layers :
4861
+ for mamba_new_name , data_torch in super ().modify_tensors (
4862
+ data_torch , name , bid
4863
+ ):
4864
+ yield mamba_new_name , data_torch
4865
+ elif bid in self ._attn_layers :
4866
+ for llama_new_name , data_torch in self ._transformer_model_class .modify_tensors (
4867
+ self , data_torch , name , bid
4868
+ ):
4869
+ yield llama_new_name , data_torch
4870
+ else :
4871
+ yield self .map_tensor_name (name ), data_torch
4872
+
4873
+
4773
4874
@ModelBase .register ("CohereForCausalLM" )
4774
4875
class CommandR2Model (TextModel ):
4775
4876
model_arch = gguf .MODEL_ARCH .COMMAND_R
0 commit comments