@@ -4127,6 +4127,14 @@ def set_gguf_parameters(self):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
@@ -4205,6 +4213,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class Mamba2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA2
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 16
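Both overrides follow the same pattern: read hparams directly from config.json so that transformers.AutoConfig is never consulted. A minimal standalone sketch of that pattern follows; the helper name and example path are hypothetical, not part of the diff.

# Standalone sketch: load model hyperparameters straight from config.json,
# bypassing transformers.AutoConfig (which mis-detects non-HF Mamba2 checkpoints).
import json
from pathlib import Path

def load_hparams_from_config(dir_model: Path) -> dict:
    # Read the raw JSON config rather than letting AutoConfig guess the model type.
    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
        return json.load(f)

# hparams = load_hparams_from_config(Path("models/mamba2-370m"))  # hypothetical path
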
@@ -5968,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any
     hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
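As a usage note, the fallback introduced here can be exercised in isolation. The sketch below mirrors the new top-level detection logic (the sub-config handling is omitted); the example inputs only illustrate the function's own behavior, assuming non-HF Mamba2 configs mark themselves via ssm_cfg's "layer" field as the diff's default implies.

# Sketch of the new detection path: no "architectures" key, so fall back to
# "ssm_cfg", whose optional "layer" field distinguishes Mamba from Mamba2.
def detect_arch(hparams: dict) -> str:
    arch = None
    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
        arch = arches[0]
    elif "ssm_cfg" in hparams:
        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
    if arch is None:
        raise ValueError("Failed to detect model architecture")
    return arch

print(detect_arch({"ssm_cfg": {}}))                   # -> "MambaForCausalLM"
print(detect_arch({"ssm_cfg": {"layer": "Mamba2"}}))  # -> "Mamba2ForCausalLM"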