@@ -4479,6 +4479,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
4479
4479
with open (dir_model / "config.json" , "r" , encoding = "utf-8" ) as f :
4480
4480
hparams = json .load (f )
4481
4481
super ().__init__ (dir_model , * args , hparams = hparams , ** kwargs )
4482
+ self .d_model = self .find_hparam (["hidden_size" , "d_model" , "dim" ])
4483
+ self .d_inner = self .find_hparam (["intermediate_size" , "d_inner" ], optional = True ) or 2 * self .d_model
4484
+ self .n_group = self .hparams .get ("n_groups" , 1 )
4482
4485
4483
4486
def set_vocab (self ):
4484
4487
vocab_size = self .hparams ["vocab_size" ]
@@ -4549,10 +4552,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
4549
4552
# (D is also unsqueezed, but for more straightforward broadcast internally)
4550
4553
data_torch = data_torch .reshape ((* data_torch .shape , 1 ))
4551
4554
elif self .match_model_tensor_name (new_name , gguf .MODEL_TENSOR .SSM_NORM , bid ):
4552
- d_model = self .find_hparam (["hidden_size" , "d_model" , "dim" ])
4553
- d_inner = self .find_hparam (["intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
4554
- n_group = self .hparams .get ("n_groups" , 1 )
4555
- data_torch = data_torch .reshape ((n_group , d_inner // n_group ))
4555
+ data_torch = data_torch .reshape ((self .n_group , self .d_inner // self .n_group ))
4556
4556
4557
4557
if name .endswith (".A_log" ):
4558
4558
logger .debug ("A_log --> A ==> " + new_name )
@@ -4561,6 +4561,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
4561
4561
yield (new_name , data_torch )
4562
4562
4563
4563
4564
+ @ModelBase .register ("BambaForCausalLM" )
4565
+ class BambaModel (Mamba2Model ):
4566
+ """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
4567
+ model_arch = gguf .MODEL_ARCH .BAMBA
4568
+ undo_permute = True
4569
+
4570
+ def __init__ (self , * args , ** kwargs ):
4571
+
4572
+ # Hybrid mamba models use a prefix for the mamba-specific params.
4573
+ # TODO: Extend this if the prefix(es) need to be configurable
4574
+ self .hparam_prefixes = ["mamba" ]
4575
+
4576
+ super ().__init__ (* args , ** kwargs )
4577
+
4578
+ # Use Llama conversion for attention
4579
+ self ._transformer_model_class : type [TextModel ] = LlamaModel
4580
+
4581
+ # Lists of which layers use ssm vs attention
4582
+ self ._attn_layers = self .hparams .get ("attn_layer_indices" , [])
4583
+ if not self ._attn_layers :
4584
+ attn_period = self .hparams .get ("attn_layer_period" )
4585
+ assert attn_period , "Didn't find attn_layer_indices or attn_layer_period"
4586
+ attn_offset = self .hparams .get ("attn_layer_offset" )
4587
+ assert attn_offset is not None , "No attention layer offset set with attn_layer_period"
4588
+ self ._attn_layers = [
4589
+ i for i in range (self .block_count )
4590
+ if i % attn_period == attn_offset
4591
+ ]
4592
+ self ._ssm_layers = [
4593
+ i for i in range (self .block_count )
4594
+ if i not in self ._attn_layers
4595
+ ]
4596
+
4597
+ # n_group and d_inner are used during reshape_tensors for mamaba2
4598
+ self .d_model = self .find_hparam (["hidden_size" , "d_model" ])
4599
+ self .n_group = self .find_hparam (["n_groups" ])
4600
+ self .d_inner = self .find_hparam (["expand" ]) * self .d_model
4601
+
4602
+ def find_hparam (self , keys : Iterable [str ], * args , ** kwargs ) -> Any :
4603
+ prefixed = []
4604
+ for pfx in self .hparam_prefixes :
4605
+ prefixed .extend (
4606
+ "_" .join ([pfx , k ])
4607
+ for k in keys
4608
+ )
4609
+ keys = list (keys ) + prefixed
4610
+ return super ().find_hparam (keys , * args , ** kwargs )
4611
+
4612
+ def set_gguf_parameters (self ):
4613
+
4614
+ ## General Params ##
4615
+ self .gguf_writer .add_embedding_length (self .d_model )
4616
+ self .gguf_writer .add_block_count (self .block_count )
4617
+ self .gguf_writer .add_context_length (self .hparams .get ("max_position_embeddings" , 0 ))
4618
+ self .gguf_writer .add_vocab_size (self .hparams ["vocab_size" ])
4619
+ self .gguf_writer .add_feed_forward_length (self .hparams ["intermediate_size" ])
4620
+
4621
+ ## Mamba mixer params ##
4622
+ self .gguf_writer .add_ssm_conv_kernel (self .find_hparam (["conv_kernel" , "d_conv" ]))
4623
+ self .gguf_writer .add_ssm_state_size (self .find_hparam (["state_size" , "d_state" ]))
4624
+ self .gguf_writer .add_ssm_group_count (self .n_group )
4625
+ self .gguf_writer .add_ssm_inner_size (self .d_inner )
4626
+ # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
4627
+ # in llama.cpp
4628
+ self .gguf_writer .add_ssm_time_step_rank (self .find_hparam (["n_heads" ]))
4629
+
4630
+ ## Attention params ##
4631
+ self .gguf_writer .add_attn_layer_indices (self ._attn_layers )
4632
+ self .gguf_writer .add_rope_dimension_count (self .hparams ["attn_rotary_emb" ])
4633
+ self .gguf_writer .add_head_count (self .hparams ["num_attention_heads" ])
4634
+ self .gguf_writer .add_head_count_kv (self .find_hparam (["num_key_value_heads" , "n_head_kv" ]))
4635
+
4636
+ ## Feed Forward Params ##
4637
+ self .gguf_writer .add_layer_norm_rms_eps (
4638
+ self .find_hparam (["layer_norm_epsilon" , "rms_norm_eps" ], optional = True ) or 1e-5
4639
+ )
4640
+
4641
+ ## Validation ##
4642
+ d_head = self .find_hparam (["d_head" ], optional = True ) or 64
4643
+ assert self .hparams .get ("hidden_act" ) in [None , "silu" ], "Only SILU activation supported"
4644
+ assert self .d_inner % d_head == 0 , f"SSM inner size { self .d_inner } not a multiple of head dim { d_head } "
4645
+
4646
+ def modify_tensors (
4647
+ self , data_torch : Tensor , name : str , bid : int | None
4648
+ ) -> Iterable [tuple [str , Tensor ]]:
4649
+
4650
+ # Determine whether this is a mamaba layer or an attention layer
4651
+ if bid in self ._ssm_layers :
4652
+ for mamba_new_name , data_torch in super ().modify_tensors (
4653
+ data_torch , name , bid
4654
+ ):
4655
+ yield mamba_new_name , data_torch
4656
+ elif bid in self ._attn_layers :
4657
+ for llama_new_name , data_torch in self ._transformer_model_class .modify_tensors (
4658
+ self , data_torch , name , bid
4659
+ ):
4660
+ yield llama_new_name , data_torch
4661
+ else :
4662
+ yield self .map_tensor_name (name ), data_torch
4663
+
4664
+
4564
4665
@ModelBase .register ("CohereForCausalLM" )
4565
4666
class CommandR2Model (TextModel ):
4566
4667
model_arch = gguf .MODEL_ARCH .COMMAND_R
0 commit comments