@@ -4971,112 +4971,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
        yield (new_name, data_torch)


-@ModelBase.register("BambaForCausalLM")
-class BambaModel(Mamba2Model):
-    """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
-    model_arch = gguf.MODEL_ARCH.BAMBA
-    undo_permute = True
-
-    def __init__(self, *args, **kwargs):
-
-        # Hybrid mamba models use a prefix for the mamba-specific params.
-        # TODO: Extend this if the prefix(es) need to be configurable
-        self.hparam_prefixes = ["mamba"]
-
-        super().__init__(*args, **kwargs)
-
-        # Use Llama conversion for attention
-        self._transformer_model_class: type[TextModel] = LlamaModel
-
-        # Lists of which layers use ssm vs attention
-        self._attn_layers = self.get_attn_layres()
-        self._ssm_layers = [
-            i for i in range(self.block_count)
-            if i not in self._attn_layers
-        ]
-
-        # n_group and d_inner are used during reshape_tensors for mamba2
-        self.d_model = self.find_hparam(["hidden_size", "d_model"])
-        self.n_group = self.find_hparam(["n_groups"])
-        self.d_inner = self.find_hparam(["expand"]) * self.d_model
-
-    def get_attn_layres(self) -> list[int]:
-        attn_layers = self.hparams.get("attn_layer_indices", [])
-        if not attn_layers:
-            attn_period = self.hparams.get("attn_layer_period")
-            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
-            attn_offset = self.hparams.get("attn_layer_offset")
-            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
-            attn_layers = [
-                i for i in range(self.block_count)
-                if i % attn_period == attn_offset
-            ]
-        return attn_layers
-
-    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
-        prefixed = []
-        for pfx in self.hparam_prefixes:
-            prefixed.extend(
-                "_".join([pfx, k])
-                for k in keys
-            )
-        keys = list(keys) + prefixed
-        return super().find_hparam(keys, *args, **kwargs)
-
-    def set_gguf_parameters(self):
-
-        ## General Params ##
-        self.gguf_writer.add_embedding_length(self.d_model)
-        self.gguf_writer.add_block_count(self.block_count)
-        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
-        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
-        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
-
-        ## Mamba mixer params ##
-        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
-        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
-        self.gguf_writer.add_ssm_group_count(self.n_group)
-        self.gguf_writer.add_ssm_inner_size(self.d_inner)
-        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
-        #   in llama.cpp
-        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
-
-        ## Attention params ##
-        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
-        if rope_dim := self.hparams.get("attn_rotary_emb"):
-            self.gguf_writer.add_rope_dimension_count(rope_dim)
-        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
-        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
-
-        ## Feed Forward Params ##
-        self.gguf_writer.add_layer_norm_rms_eps(
-            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
-        )
-
-        ## Validation ##
-        d_head = self.find_hparam(["d_head"], optional=True) or 64
-        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
-        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
-
-    def modify_tensors(
-        self, data_torch: Tensor, name: str, bid: int | None
-    ) -> Iterable[tuple[str, Tensor]]:
-
-        # Determine whether this is a mamba layer or an attention layer
-        if bid in self._ssm_layers:
-            for mamba_new_name, data_torch in super().modify_tensors(
-                data_torch, name, bid
-            ):
-                yield mamba_new_name, data_torch
-        elif bid in self._attn_layers:
-            for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
-                self, data_torch, name, bid
-            ):
-                yield llama_new_name, data_torch
-        else:
-            yield self.map_tensor_name(name), data_torch
-
-
@ModelBase.register("JambaForCausalLM")
class JambaModel(TextModel):
    model_arch = gguf.MODEL_ARCH.JAMBA
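For reference, the layer partition that both the removed BambaModel and its replacement rely on can be sketched in isolation: attention layers come either from an explicit `attn_layer_indices` list or from an `attn_layer_period`/`attn_layer_offset` pair, and every remaining block is treated as an SSM layer. This is a minimal, self-contained sketch; the hyperparameter values below are illustrative only, not taken from any real config.

```python
# Sketch of the attention/SSM layer partition used by the hybrid converters.
def select_attn_layers(hparams: dict, block_count: int) -> list[int]:
    # An explicit index list wins if present; otherwise derive from period/offset.
    attn_layers = hparams.get("attn_layer_indices", [])
    if not attn_layers:
        period = hparams["attn_layer_period"]
        offset = hparams["attn_layer_offset"]
        attn_layers = [i for i in range(block_count) if i % period == offset]
    return attn_layers


if __name__ == "__main__":
    # Hypothetical config: 12 blocks, every 6th block (offset 3) is attention.
    hparams = {"attn_layer_period": 6, "attn_layer_offset": 3}
    attn = select_attn_layers(hparams, block_count=12)
    ssm = [i for i in range(12) if i not in attn]
    print(attn)  # [3, 9]
    print(ssm)   # every other block index, handled by the Mamba2 path
```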
@@ -6579,19 +6473,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
        return super().modify_tensors(data_torch, name, bid)


-@ModelBase.register("GraniteMoeHybridForCausalLM")
-class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
-    """GraniteMoeHybrid is a hybrid SSM + MoE Attention model that uses Mamba2
-    SSM layers"""
-    model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID
+@ModelBase.register("GraniteMoeHybridForCausalLM", "BambaForCausalLM")
+class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
+    """GraniteHybrid is a hybrid SSM + Attention model that uses Mamba2 SSM
+    layers and optionally uses MoE w/ a shared expert"""
+    model_arch = gguf.MODEL_ARCH.GRANITE_HYBRID
+    undo_permute = True
+
+    def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
+        super().__init__(*args, **kwargs)
+
+        # Use Granite conversion for attention
+        self._transformer_model_class: type[TextModel] = GraniteModel
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.get_attn_layres()
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model

    def get_attn_layres(self):
+        # Explicit list of layer type names
        if layer_types := self.hparams.get("layer_types"):
            return [
                i for i, typ in enumerate(layer_types)
                if typ == "attention"
            ]
-        return super().get_attn_layres()
+
+        # Layer types indicated by index or period
+        attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        return attn_layers
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)

    def modify_tensors(
        self, data_torch: Tensor, name: str, bid: int | None
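The `find_hparam` override above makes the lookup fall back to prefixed keys, so hybrid configs that namespace their SSM parameters (the `mamba_*` keys) resolve with the same call as plain keys. A minimal sketch of that fallback follows; the config dict is illustrative, not taken from a real model.

```python
# Sketch of the hparam-prefix fallback: a lookup for "n_heads" also tries
# "mamba_n_heads" when the prefix list contains "mamba".
from typing import Any, Iterable


def find_hparam(hparams: dict, keys: Iterable[str], prefixes: tuple[str, ...] = ("mamba",)) -> Any:
    keys = list(keys)
    prefixed = ["_".join([pfx, k]) for pfx in prefixes for k in keys]
    for k in keys + prefixed:
        if k in hparams:
            return hparams[k]
    raise KeyError(f"none of {keys + prefixed} found in hparams")


config = {"mamba_n_heads": 128, "hidden_size": 4096}
print(find_hparam(config, ["n_heads"]))      # 128, found via the "mamba_" prefix
print(find_hparam(config, ["hidden_size"]))  # 4096, unprefixed key found directly
```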
@@ -6601,11 +6542,53 @@ def modify_tensors(
            or "shared_mlp" in name
        ):
            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
-        return super().modify_tensors(data_torch, name, bid)
+
+        # Determine whether this is a mamba layer or an attention layer
+        if bid in self._ssm_layers:
+            return super().modify_tensors(data_torch, name, bid)
+        elif bid in self._attn_layers:
+            return self._transformer_model_class.modify_tensors(self, data_torch, name, bid)
+        return [(self.map_tensor_name(name), data_torch)]

    def set_gguf_parameters(self):
        GraniteMoeModel.set_gguf_parameters(self)
-        BambaModel.set_gguf_parameters(self)
+
+        ## General Params ##
+        self.gguf_writer.add_embedding_length(self.d_model)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        #   in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+
+        ## Attention params ##
+        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
+        if rope_dim := self.hparams.get("attn_rotary_emb"):
+            self.gguf_writer.add_rope_dimension_count(rope_dim)
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
+
+        ## Feed Forward Params ##
+        self.gguf_writer.add_layer_norm_rms_eps(
+            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+        )
+
+        ## If Bamba, use rope, otherwise don't
+        use_rope = "BambaForCausalLM" in self.hparams["architectures"]
+        self.gguf_writer.add_rope_scaling_finetuned(use_rope)
+
+        ## Validation ##
+        d_head = self.find_hparam(["d_head"], optional=True) or 64
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"

    def set_vocab(self):
        self.hparams["pad_vocab_size_multiple"] = 8
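The new `modify_tensors` routes each tensor by its block id: SSM blocks go through the Mamba2 conversion, attention blocks through the Granite conversion, and everything else (embeddings, output head) is mapped by name only. Below is a standalone sketch of that dispatch under stated assumptions: the converter functions are hypothetical stubs standing in for `Mamba2Model.modify_tensors` and `GraniteModel.modify_tensors`, and the layer indices are illustrative.

```python
# Sketch of per-layer tensor dispatch in a hybrid SSM + attention converter.
def convert_ssm(name: str) -> str:
    return f"ssm.{name}"    # stub for the Mamba2 conversion path

def convert_attn(name: str) -> str:
    return f"attn.{name}"   # stub for the Granite attention conversion path

def dispatch(name: str, bid: int | None, attn_layers: list[int], ssm_layers: list[int]) -> str:
    if bid in ssm_layers:
        return convert_ssm(name)
    if bid in attn_layers:
        return convert_attn(name)
    return name             # non-layer tensors (e.g. embeddings) map by name only

attn_layers = [3, 9]
ssm_layers = [i for i in range(12) if i not in attn_layers]
print(dispatch("model.layers.3.self_attn.q_proj.weight", 3, attn_layers, ssm_layers))
print(dispatch("model.layers.0.mamba.in_proj.weight", 0, attn_layers, ssm_layers))
print(dispatch("model.embed_tokens.weight", None, attn_layers, ssm_layers))
```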