Commit d55b0d0

convert : avoid AutoConfig for Mamba and Mamba2 hparams
1 parent 929fe85 commit d55b0d0

File tree

1 file changed: +26 -1 lines changed

convert_hf_to_gguf.py

Lines changed: 26 additions & 1 deletion
@@ -4127,6 +4127,14 @@ def set_gguf_parameters(self):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
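
The `__init__` override above bypasses `transformers.AutoConfig` and reads the checkpoint's `config.json` directly. As a standalone sketch of that pattern (the helper name `load_hparams_directly` is illustrative, not part of the script), assuming an original state-spaces Mamba checkpoint whose `config.json` carries no `model_type` or `architectures` keys for AutoConfig to dispatch on:

    import json
    from pathlib import Path
    from typing import Any

    def load_hparams_directly(dir_model: Path) -> dict[str, Any]:
        # Read config.json as-is; no transformers.AutoConfig involved
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

    # Illustrative non-HF Mamba config contents (no "model_type" or
    # "architectures" keys, unlike HF-converted checkpoints):
    # {"d_model": 2560, "n_layer": 64, "vocab_size": 50277, "ssm_cfg": {}}
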
@@ -4205,6 +4213,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class Mamba2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA2
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 16
@@ -5968,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any
     hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
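
To see how the new fallback in `get_model_architecture` behaves, here is a trimmed stand-in (`detect_arch` is a hypothetical name; the text_config/vision_config handling is omitted) run on illustrative configs. Original state-spaces Mamba2 checkpoints set `"layer": "Mamba2"` inside `ssm_cfg`, while original Mamba checkpoints leave it unset, which is what the `"Mamba"` default covers:

    def detect_arch(hparams: dict) -> str:
        # Mirrors the diff above, minus the sub-config handling
        arch = None
        if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
            arch = arches[0]
        elif "ssm_cfg" in hparams:
            # For non-hf Mamba and Mamba2 models
            arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
        if arch is None:
            raise ValueError("Failed to detect model architecture")
        return arch

    print(detect_arch({"architectures": ["LlamaForCausalLM"]}))  # LlamaForCausalLM
    print(detect_arch({"d_model": 2560, "ssm_cfg": {}}))         # MambaForCausalLM
    print(detect_arch({"ssm_cfg": {"layer": "Mamba2"}}))         # Mamba2ForCausalLM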
