Commit 01fa72d

feat: Add conversion for Bamba models
This is borrowed and adapted from the original implementation in ggml-org#10810.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 830e554 commit 01fa72d

File tree

4 files changed: +152 additions, -7 deletions

convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py

convert_hf_to_gguf.py

Lines changed: 105 additions & 4 deletions
```diff
@@ -4723,6 +4723,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             hparams = json.load(f)
         super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+        self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
+        self.n_group = self.hparams.get("n_groups", 1)
 
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
```
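For context: `find_hparam` returns the value of the first candidate key present in the loaded `config.json`, which is why `d_inner` falls back to `2 * d_model` when no explicit inner size is given. Below is a minimal standalone sketch of that lookup; the helper name `find_hparam_like` and the toy config are illustrative stand-ins, not part of the codebase.

```python
from typing import Any


def find_hparam_like(hparams: dict[str, Any], keys: list[str], optional: bool = False) -> Any:
    """Return the first candidate key found in the config, mimicking the converter's lookup order."""
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of: {keys}")


# Toy Mamba2-style config with no explicit inner size.
hparams = {"hidden_size": 2048, "n_groups": 8}
d_model = find_hparam_like(hparams, ["hidden_size", "d_model", "dim"])
d_inner = find_hparam_like(hparams, ["intermediate_size", "d_inner"], optional=True) or 2 * d_model
n_group = hparams.get("n_groups", 1)
print(d_model, d_inner, n_group)  # 2048 4096 8
```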
```diff
@@ -4793,10 +4796,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             # (D is also unsqueezed, but for more straightforward broadcast internally)
             data_torch = data_torch.reshape((*data_torch.shape, 1))
         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
-            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
-            n_group = self.hparams.get("n_groups", 1)
-            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+            data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))
 
         if name.endswith(".A_log"):
             logger.debug("A_log --> A ==> " + new_name)
```
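The `SSM_NORM` branch now reuses the values cached in `__init__`; the reshape itself just splits the flat norm weight into one row per group. A small sketch with made-up dimensions:

```python
import torch

# Illustrative dimensions only; real values come from the model's config.json.
n_group, d_inner = 8, 4096
ssm_norm = torch.randn(d_inner)                     # flat norm weight as stored in the checkpoint
grouped = ssm_norm.reshape(n_group, d_inner // n_group)
print(grouped.shape)                                # torch.Size([8, 512])
```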
```diff
@@ -4805,6 +4805,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield (new_name, data_torch)
 
 
+@ModelBase.register("BambaForCausalLM")
+class BambaModel(Mamba2Model):
+    """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
+    model_arch = gguf.MODEL_ARCH.BAMBA
+    undo_permute = True
+
+    def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
+        super().__init__(*args, **kwargs)
+
+        # Use Llama conversion for attention
+        self._transformer_model_class: type[TextModel] = LlamaModel
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not self._attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            self._attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
+    def set_gguf_parameters(self):
+
+        ## General Params ##
+        self.gguf_writer.add_embedding_length(self.d_model)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        #   in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+
+        ## Attention params ##
+        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
+        self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
+
+        ## Feed Forward Params ##
+        self.gguf_writer.add_layer_norm_rms_eps(
+            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+        )
+
+        ## Validation ##
+        d_head = self.find_hparam(["d_head"], optional=True) or 64
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+
+        # Determine whether this is a mamba layer or an attention layer
+        if bid in self._ssm_layers:
+            for mamba_new_name, data_torch in super().modify_tensors(
+                data_torch, name, bid
+            ):
+                yield mamba_new_name, data_torch
+        elif bid in self._attn_layers:
+            for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
+                self, data_torch, name, bid
+            ):
+                yield llama_new_name, data_torch
+        else:
+            yield self.map_tensor_name(name), data_torch
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
```
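The attention/SSM split above is the heart of the hybrid conversion: explicit `attn_layer_indices` win, otherwise the indices are derived from `attn_layer_period` and `attn_layer_offset`, and every remaining block is treated as a Mamba2 layer and routed through the corresponding conversion path. A self-contained sketch of that scheduling with assumed values (a 32-block model, period 8, offset 4; a real Bamba config may differ):

```python
block_count = 32   # assumed number of blocks
attn_period = 8    # hypothetical attn_layer_period
attn_offset = 4    # hypothetical attn_layer_offset

attn_layers = [i for i in range(block_count) if i % attn_period == attn_offset]
ssm_layers = [i for i in range(block_count) if i not in attn_layers]

print(attn_layers)      # [4, 12, 20, 28]  -> converted via the LlamaModel path
print(len(ssm_layers))  # 28               -> converted via the Mamba2Model path
```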

gguf-py/gguf/constants.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -167,6 +167,9 @@ class SSM:
         GROUP_COUNT = "{arch}.ssm.group_count"
         DT_B_C_RMS  = "{arch}.ssm.dt_b_c_rms"
 
+    class HybridAttention:
+        ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
+
     class WKV:
         HEAD_SIZE = "{arch}.wkv.head_size"
 
@@ -321,6 +324,7 @@ class MODEL_ARCH(IntEnum):
     ARWKV7    = auto()
     MAMBA     = auto()
     MAMBA2    = auto()
+    BAMBA     = auto()
     XVERSE    = auto()
     COMMAND_R = auto()
     COHERE2   = auto()
@@ -606,6 +610,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.ARWKV7: "arwkv7",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.MAMBA2: "mamba2",
+    MODEL_ARCH.BAMBA: "bamba",
     MODEL_ARCH.XVERSE: "xverse",
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.COHERE2: "cohere2",
@@ -1654,6 +1659,31 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.BAMBA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.XVERSE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
```
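With these constants in place, the new metadata key is produced by formatting the template with the architecture string, and `MODEL_TENSORS` whitelists which tensor kinds the converter may emit for `bamba`. A quick sketch (assumes a gguf-py checkout that includes this commit):

```python
import gguf

# The hybrid-attention key template, formatted for the "bamba" architecture.
key = gguf.Keys.HybridAttention.ATTN_LAYER_INDICES.format(arch="bamba")
print(key)  # bamba.attention.layer_indices

# Tensor kinds registered for the new architecture (23 entries in the hunk above).
print(len(gguf.MODEL_TENSORS[gguf.MODEL_ARCH.BAMBA]))
```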

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -849,6 +849,9 @@ def add_ssm_group_count(self, value: int) -> None:
     def add_ssm_dt_b_c_rms(self, value: bool) -> None:
         self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
 
+    def add_attn_layer_indices(self, values: list[int]) -> None:
+        self.add_array(Keys.HybridAttention.ATTN_LAYER_INDICES.format(arch=self.arch), values)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
 
```
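The writer method simply stores the list under the formatted key as a GGUF array. A hedged usage sketch (the file name and index values are illustrative only, and it assumes gguf-py with this commit installed):

```python
import gguf

writer = gguf.GGUFWriter("bamba-example.gguf", arch="bamba")
writer.add_attn_layer_indices([4, 12, 20, 28])  # stored as "bamba.attention.layer_indices"

# Flush metadata and close; no tensors are added in this toy example.
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
```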
gguf-py/gguf/tensor_mapping.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte",  # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings",  # falcon
             "word_embeddings",  # bloom
-            "model.embed_tokens",  # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
+            "model.embed_tokens",  # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 bamba
             "tok_embeddings",  # llama-pth
             "embeddings.word_embeddings",  # bert nomic-bert
             "language_model.embedding.word_embeddings",  # persimmon
@@ -118,7 +118,7 @@ class TensorNameMap:
             "transformer.h.{bid}.input_layernorm",  # falcon7b
             "h.{bid}.input_layernorm",  # bloom
             "transformer.h.{bid}.ln_mlp",  # falcon40b
-            "model.layers.{bid}.input_layernorm",  # llama-hf nemotron olmoe phimoe
+            "model.layers.{bid}.input_layernorm",  # llama-hf nemotron olmoe phimoe bamba
             "layers.{bid}.attention_norm",  # llama-pth
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",  # yi
@@ -281,6 +281,7 @@
             "transformer.layers.{bid}.ffn_norm",  # openelm
             "model.layers.{bid}.post_attention_layernorm",  # llama4
             "transformer_encoder.{bid}.ffn_norm",  # neobert
+            "model.layers.{bid}.pre_ff_layernorm",  # bamba
         ),
 
         # Post feed-forward norm
@@ -346,6 +347,7 @@
             "transformer.h.{bid}.mlp.c_fc_1",  # exaone
             "model.layers.{bid}.feed_forward.up_proj",  # llama4
             "transformer_encoder.{bid}.ffn.w12",  # neobert
+            "model.layers.{bid}.feed_forward.up_proj",  # bamba
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -382,7 +384,8 @@
             "transformer.h.{bid}.mlp.linear_1",  # refact
             "model.layers.{bid}.residual_mlp.w1",  # arctic
             "transformer.h.{bid}.mlp.c_fc_0",  # exaone
-            "model.layers.{bid}.feed_forward.gate_proj",  # llama4
+            "language_model.model.layers.{bid}.feed_forward.gate_proj",  # llama4
+            "model.layers.{bid}.feed_forward.gate_proj",  # bamba
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -429,6 +432,7 @@
             "model.layers.h.{bid}.mlp.c_proj",  # exaone
             "model.layers.{bid}.feed_forward.down_proj",  # llama4
             "transformer_encoder.{bid}.ffn.w3",  # neobert
+            "model.layers.{bid}.feed_forward.down_proj",  # bamba
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -483,11 +487,13 @@
         MODEL_TENSOR.SSM_IN: (
             "model.layers.{bid}.in_proj",
             "backbone.layers.{bid}.mixer.in_proj",
+            "model.layers.{bid}.mamba.in_proj",  # bamba
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
             "model.layers.{bid}.conv1d",
             "backbone.layers.{bid}.mixer.conv1d",
+            "model.layers.{bid}.mamba.conv1d",  # bamba
         ),
 
         MODEL_TENSOR.SSM_X: (
@@ -498,25 +504,30 @@
         MODEL_TENSOR.SSM_DT: (
             "model.layers.{bid}.dt_proj",
             "backbone.layers.{bid}.mixer.dt_proj",
+            "model.layers.{bid}.mamba.dt_proj",  # bamba
         ),
 
         MODEL_TENSOR.SSM_A: (
             "model.layers.{bid}.A_log",
             "backbone.layers.{bid}.mixer.A_log",
+            "model.layers.{bid}.mamba.A_log",  # bamba
         ),
 
         MODEL_TENSOR.SSM_D: (
             "model.layers.{bid}.D",
             "backbone.layers.{bid}.mixer.D",
+            "model.layers.{bid}.mamba.D",  # bamba
         ),
 
         MODEL_TENSOR.SSM_NORM: (
             "backbone.layers.{bid}.mixer.norm",  # mamba2
+            "model.layers.{bid}.mamba.norm",  # bamba
         ),
 
         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",
             "backbone.layers.{bid}.mixer.out_proj",
+            "model.layers.{bid}.mamba.out_proj",  # bamba
         ),
 
         MODEL_TENSOR.TIME_MIX_W0: (
```
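These additional source names let `TensorNameMap` resolve the `model.layers.{bid}.mamba.*` and `model.layers.{bid}.feed_forward.*` names found in Bamba checkpoints to their GGUF tensor names. A small sketch of the lookup (the block count is an assumed value):

```python
import gguf

# Build the name map for the new architecture; 32 blocks is an assumption for illustration.
name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.BAMBA, 32)

gguf_name = name_map.get_name("model.layers.3.mamba.in_proj.weight", try_suffixes=(".weight", ".bias"))
print(gguf_name)  # expected: blk.3.ssm_in.weight
```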
