
Commit 463272d

Merge remote-tracking branch 'origin/compilade/mamba2' into GraniteFour
* origin/compilade/mamba2:
  kv-cache : allow context shift for recurrent models
  convert : avoid AutoConfig for Mamba and Mamba2 hparams
2 parents c3b7922 + e94f393 commit 463272d

File tree

2 files changed: +27 -4 lines changed

convert_hf_to_gguf.py
src/llama-kv-cache.cpp

convert_hf_to_gguf.py

Lines changed: 25 additions & 3 deletions
@@ -4243,6 +4243,14 @@ def set_gguf_parameters(self):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
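The new constructor bypasses transformers' AutoConfig and takes the hyperparameters straight from the checkpoint's config.json, so non-HF Mamba releases keep raw keys such as "d_model" and "ssm_cfg" instead of being coerced into a guessed config class. A minimal standalone sketch of the same pattern (the directory name is hypothetical and this helper is not part of the converter):

```python
import json
from pathlib import Path

def load_raw_hparams(dir_model: Path) -> dict:
    """Return the hyperparameters exactly as stored in config.json,
    without routing them through a transformers config class."""
    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
        return json.load(f)

# Hypothetical local checkpoint directory:
hparams = load_raw_hparams(Path("./mamba-130m"))
print(hparams.get("d_model"), hparams.get("ssm_cfg", {}))
```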
@@ -4321,8 +4329,14 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class Mamba2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA2
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
         self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
         self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
         self.n_group = self.hparams.get("n_groups", 1)
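The constructor resolves the model dimensions through find_hparam, trying several key spellings and falling back to a computed default. A simplified stand-in for that lookup (not the converter's actual implementation) shows the order of precedence and the 2 * d_model fallback for d_inner:

```python
from typing import Any

def find_hparam(hparams: dict[str, Any], keys: list[str], optional: bool = False) -> Any:
    # Return the value of the first key that is present in the config.
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of: {keys}")

# Hypothetical Mamba2 config using the non-HF key names:
hparams = {"d_model": 2048, "n_groups": 1}
d_model = find_hparam(hparams, ["hidden_size", "d_model", "dim"])
d_inner = find_hparam(hparams, ["intermediate_size", "d_inner"], optional=True) or 2 * d_model
n_group = hparams.get("n_groups", 1)
assert (d_model, d_inner, n_group) == (2048, 4096, 1)
```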
@@ -6225,12 +6239,20 @@ def split_str_to_n_bytes(split_str: str) -> int:
 def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
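The net effect of the get_model_architecture change is an extra fallback for checkpoints that have no "architectures" list: the ssm_cfg block identifies non-HF Mamba and Mamba2 models. A condensed sketch of just that detection logic (omitting the text/vision sub-config override):

```python
from typing import Any

def detect_arch(hparams: dict[str, Any]) -> str:
    # Same fallback order as the diff: explicit "architectures" first, then ssm_cfg.
    if (arches := hparams.get("architectures")):
        return arches[0]
    if "ssm_cfg" in hparams:
        return hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
    raise ValueError("Failed to detect model architecture")

# HF-style config with an explicit architectures list:
assert detect_arch({"architectures": ["MambaForCausalLM"]}) == "MambaForCausalLM"
# Original (non-hf) style configs: no architectures list, only ssm_cfg.
assert detect_arch({"ssm_cfg": {"layer": "Mamba2"}}) == "Mamba2ForCausalLM"
assert detect_arch({"ssm_cfg": {}}) == "MambaForCausalLM"
```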

src/llama-kv-cache.cpp

Lines changed: 2 additions & 1 deletion
@@ -1938,7 +1938,8 @@ llama_pos llama_kv_cache_recurrent::get_pos_max() const {
 }
 
 bool llama_kv_cache_recurrent::get_can_shift() const {
-    return false;
+    // shifting is trivial, the recurrent states don't care about the absolute position
+    return true;
 }
 
 uint32_t llama_kv_cache_recurrent::cell_max() const {
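The new comment states why the shift is safe: a recurrent cache stores one fixed-size state per sequence rather than per-position entries, so changing absolute positions is pure bookkeeping. A toy Python sketch of that idea (illustrative only, not the llama.cpp implementation, which lives in llama-kv-cache.cpp):

```python
from dataclasses import dataclass

@dataclass
class RecurrentCell:
    state: list[float]  # stands in for the fixed-size SSM/recurrent state
    pos: int            # last absolute position absorbed into the state

def shift(cell: RecurrentCell, delta: int) -> None:
    # Unlike an attention KV cache, there are no per-position entries to
    # rebuild or re-rotate; only the recorded position moves.
    cell.pos += delta

cell = RecurrentCell(state=[0.1, 0.2], pos=4095)
shift(cell, -2048)
assert cell.state == [0.1, 0.2] and cell.pos == 2047
```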
