Commit 0d80ee7

Merge remote-tracking branch 'origin/compilade/mamba2' into mamba2-sync
* origin/compilade/mamba2: (27 commits)
  ggml-cpu : reorder SVE FMA for consistency with other SIMD arches
  ggml : fix mamba2 ssm scan when compiled with SVE
  graph : fix recurrent state copies when avoiding copies
  kv-cache : allow context shift for recurrent models
  convert : avoid AutoConfig for Mamba and Mamba2 hparams
  kv-cache : remove const_cast when setting inputs for s_copy
  metal : single-user mamba2 inference works
  metal : add missing args for nb references in ssm_scan_f32_group
  metal : fix confusion between ; and ,
  convert : fix flake8 lint
  ggml : avoid multiply by D in GGML_OP_SSM_SCAN
  ggml : remove unused fast broadcast path in GGML_MUL
  metal : fix wrong number of tokens per sequence in SSM_SCAN
  metal : fix SSM_SCAN state head offset
  metal : add back n_seqs to SSM_SCAN args
  metal : remove unused arguments for SSM_SCAN
  metal : use log and exp instead of log1pf and expf in SSM_SCAN
  metal : fix SSM_SCAN pipeline scope
  metal : attempt to adapt SSM_SCAN for Mamba-2
  llama : avoid redundant state copy for Mamba 1 and 2
  ...
2 parents 6253c7c + 0b6f6be commit 0d80ee7

23 files changed: +892, -271 lines

convert_hf_to_gguf.py

Lines changed: 111 additions & 1 deletion
@@ -4594,6 +4594,14 @@ def set_gguf_parameters(self):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
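
The __init__ override above (mirrored by Mamba2Model further down) loads hparams straight from config.json instead of going through AutoConfig, which, per the comments in this diff, wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1. A minimal standalone sketch of that loading step, using a hypothetical checkpoint directory:

    # Standalone sketch of the hparams loading done in __init__ above.
    # "mamba-130m" is a hypothetical local checkpoint directory.
    import json
    from pathlib import Path

    dir_model = Path("mamba-130m")
    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
        hparams = json.load(f)  # plain dict, no transformers.AutoConfig involved
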
@@ -4668,6 +4676,100 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(new_name, data_torch)]
 
 
+@ModelBase.register("Mamba2ForCausalLM")
+class Mamba2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.MAMBA2
+
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
+    def set_vocab(self):
+        vocab_size = self.hparams["vocab_size"]
+        # Round vocab size to next multiple of 16
+        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
+        # pad using ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
+        self.hparams["vocab_size"] = vocab_size
+
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+        elif (self.dir_model / "tokenizer.model.v3").is_file():
+            # mamba-codestral
+            raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
+        elif (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            # Use the GPT-NeoX tokenizer when no tokenizer files are present
+            self._set_vocab_builtin("gpt-neox", vocab_size)
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        # Fail early for models which don't have a block expansion factor of 2
+        # TODO: does this really matter?
+        assert d_inner == 2 * d_model
+        assert d_inner % head_dim == 0
+
+        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        if name.startswith("model.backbone") or name.startswith("model.lm_head"):
+            # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
+            name = name.removeprefix("model.")
+
+        if name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+
+        new_name = self.map_tensor_name(name)
+
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+        elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
+            gguf.MODEL_TENSOR.SSM_A,
+            gguf.MODEL_TENSOR.SSM_D,
+        ]):
+            # unsqueeze A to use similar shape semantics as Mamba-1
+            # (D is also unsqueezed, but for more straightforward broadcast internally)
+            data_torch = data_torch.reshape((*data_torch.shape, 1))
+        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
+            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            n_group = self.hparams.get("n_groups", 1)
+            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
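
A note on the vocab padding in Mamba2Model.set_vocab above: -(vocab_size // -pad_vocab) is ceiling division, so the vocabulary size is rounded up to the next multiple of pad_vocab. A small worked sketch (the numbers are hypothetical, not taken from any particular checkpoint):

    # Ceiling-division padding as used in set_vocab above (hypothetical numbers).
    vocab_size = 50277                                  # raw vocab size
    pad_vocab = 16                                      # pad_vocab_size_multiple default
    padded = -(vocab_size // -pad_vocab) * pad_vocab    # ceil(50277 / 16) * 16
    print(padded)                                       # 50288
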
@@ -6407,12 +6509,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
     # maybe we should fallback to text model's arch in that case, since not many models have both
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
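To illustrate the new fallback in get_model_architecture: when config.json has no "architectures" list but does contain an "ssm_cfg" block (the case for non-hf Mamba and Mamba2 checkpoints), the architecture name is derived from ssm_cfg["layer"]. A standalone sketch with a made-up config dict:

    # Hypothetical non-hf Mamba2 config.json contents (values invented for illustration).
    hparams = {"d_model": 2560, "n_layer": 64, "ssm_cfg": {"layer": "Mamba2"}}

    arch = None
    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
        arch = arches[0]
    elif "ssm_cfg" in hparams:
        # For non-hf Mamba and Mamba2 models
        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"

    print(arch)  # Mamba2ForCausalLM -> handled by the Mamba2Model class registered above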

ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
@@ -1878,7 +1878,8 @@ extern "C" {
             struct ggml_tensor * dt,
             struct ggml_tensor * A,
             struct ggml_tensor * B,
-            struct ggml_tensor * C);
+            struct ggml_tensor * C,
+            struct ggml_tensor * ids);
 
     // partition into non-overlapping windows with padding if needed
     // example:
