@@ -4243,6 +4243,14 @@ def set_gguf_parameters(self):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
@@ -4321,8 +4329,14 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class Mamba2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA2
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
         self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
         self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
         self.n_group = self.hparams.get("n_groups", 1)
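
For context, a hedged sketch of the kind of checkpoint these overrides target: non-HF Mamba/Mamba2 checkpoints typically ship a config.json without an "architectures" key and with SSM settings under "ssm_cfg" (see the hunk below), which is why both converters now read the file directly instead of letting transformers.AutoConfig guess a config class. The concrete keys and values here are hypothetical, for illustration only.

    # Hypothetical non-HF Mamba2 config.json contents, shown as a Python dict
    example_hparams = {
        "d_model": 2560,
        "vocab_size": 50277,
        "ssm_cfg": {"layer": "Mamba2"},  # note: no "architectures" key anywhere
    }
    # With these hparams: find_hparam(["hidden_size", "d_model", "dim"]) resolves
    # to "d_model" (2560), d_inner falls back to 2 * d_model = 5120, and n_group
    # defaults to 1 because "n_groups" is absent.
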
@@ -6225,12 +6239,20 @@ def split_str_to_n_bytes(split_str: str) -> int:
 def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
 
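
For reference, a minimal standalone sketch (not part of the diff) of the new architecture-detection fallback, omitting the text_config/vision_config sub-config handling; the example configs are hypothetical.

    from typing import Any

    def detect_arch(hparams: dict[str, Any]) -> str:
        # Mirrors the fallback above: prefer an explicit "architectures" entry,
        # otherwise derive the name from ssm_cfg for non-HF Mamba/Mamba2 checkpoints.
        arch = None
        if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
            arch = arches[0]
        elif "ssm_cfg" in hparams:
            arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
        if arch is None:
            raise ValueError("Failed to detect model architecture")
        return arch

    print(detect_arch({"architectures": ["LlamaForCausalLM"]}))            # LlamaForCausalLM
    print(detect_arch({"d_model": 2560, "ssm_cfg": {"layer": "Mamba2"}}))  # Mamba2ForCausalLM
    print(detect_arch({"d_model": 2560, "ssm_cfg": {}}))                   # MambaForCausalLM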