Skip to content

Commit 216f598

Browse files
committed
feat: Add conversion for Bamba models
This is borrowed and adapted from the original implementation ggml-org#10810 Branch: GraniteFour Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent c094df9 commit 216f598

File tree

4 files changed

+155
-10
lines changed

4 files changed

+155
-10
lines changed

convert_hf_to_gguf.py

Lines changed: 105 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4479,6 +4479,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
44794479
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
44804480
hparams = json.load(f)
44814481
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
4482+
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
4483+
self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
4484+
self.n_group = self.hparams.get("n_groups", 1)
44824485

44834486
def set_vocab(self):
44844487
vocab_size = self.hparams["vocab_size"]
@@ -4549,10 +4552,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
45494552
# (D is also unsqueezed, but for more straightforward broadcast internally)
45504553
data_torch = data_torch.reshape((*data_torch.shape, 1))
45514554
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
4552-
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
4553-
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
4554-
n_group = self.hparams.get("n_groups", 1)
4555-
data_torch = data_torch.reshape((n_group, d_inner // n_group))
4555+
data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))
45564556

45574557
if name.endswith(".A_log"):
45584558
logger.debug("A_log --> A ==> " + new_name)
@@ -4561,6 +4561,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
45614561
yield (new_name, data_torch)
45624562

45634563

4564+
@ModelBase.register("BambaForCausalLM")
4565+
class BambaModel(Mamba2Model):
4566+
"""Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
4567+
model_arch = gguf.MODEL_ARCH.BAMBA
4568+
undo_permute = True
4569+
4570+
def __init__(self, *args, **kwargs):
4571+
4572+
# Hybrid mamba models use a prefix for the mamba-specific params.
4573+
# TODO: Extend this if the prefix(es) need to be configurable
4574+
self.hparam_prefixes = ["mamba"]
4575+
4576+
super().__init__(*args, **kwargs)
4577+
4578+
# Use Llama conversion for attention
4579+
self._transformer_model_class: type[TextModel] = LlamaModel
4580+
4581+
# Lists of which layers use ssm vs attention
4582+
self._attn_layers = self.hparams.get("attn_layer_indices", [])
4583+
if not self._attn_layers:
4584+
attn_period = self.hparams.get("attn_layer_period")
4585+
assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
4586+
attn_offset = self.hparams.get("attn_layer_offset")
4587+
assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
4588+
self._attn_layers = [
4589+
i for i in range(self.block_count)
4590+
if i % attn_period == attn_offset
4591+
]
4592+
self._ssm_layers = [
4593+
i for i in range(self.block_count)
4594+
if i not in self._attn_layers
4595+
]
4596+
4597+
# n_group and d_inner are used during reshape_tensors for mamaba2
4598+
self.d_model = self.find_hparam(["hidden_size", "d_model"])
4599+
self.n_group = self.find_hparam(["n_groups"])
4600+
self.d_inner = self.find_hparam(["expand"]) * self.d_model
4601+
4602+
def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
4603+
prefixed = []
4604+
for pfx in self.hparam_prefixes:
4605+
prefixed.extend(
4606+
"_".join([pfx, k])
4607+
for k in keys
4608+
)
4609+
keys = list(keys) + prefixed
4610+
return super().find_hparam(keys, *args, **kwargs)
4611+
4612+
def set_gguf_parameters(self):
4613+
4614+
## General Params ##
4615+
self.gguf_writer.add_embedding_length(self.d_model)
4616+
self.gguf_writer.add_block_count(self.block_count)
4617+
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
4618+
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
4619+
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
4620+
4621+
## Mamba mixer params ##
4622+
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
4623+
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
4624+
self.gguf_writer.add_ssm_group_count(self.n_group)
4625+
self.gguf_writer.add_ssm_inner_size(self.d_inner)
4626+
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
4627+
# in llama.cpp
4628+
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
4629+
4630+
## Attention params ##
4631+
self.gguf_writer.add_attn_layer_indices(self._attn_layers)
4632+
self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"])
4633+
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
4634+
self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
4635+
4636+
## Feed Forward Params ##
4637+
self.gguf_writer.add_layer_norm_rms_eps(
4638+
self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
4639+
)
4640+
4641+
## Validation ##
4642+
d_head = self.find_hparam(["d_head"], optional=True) or 64
4643+
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
4644+
assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
4645+
4646+
def modify_tensors(
4647+
self, data_torch: Tensor, name: str, bid: int | None
4648+
) -> Iterable[tuple[str, Tensor]]:
4649+
4650+
# Determine whether this is a mamaba layer or an attention layer
4651+
if bid in self._ssm_layers:
4652+
for mamba_new_name, data_torch in super().modify_tensors(
4653+
data_torch, name, bid
4654+
):
4655+
yield mamba_new_name, data_torch
4656+
elif bid in self._attn_layers:
4657+
for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
4658+
self, data_torch, name, bid
4659+
):
4660+
yield llama_new_name, data_torch
4661+
else:
4662+
yield self.map_tensor_name(name), data_torch
4663+
4664+
45644665
@ModelBase.register("CohereForCausalLM")
45654666
class CommandR2Model(TextModel):
45664667
model_arch = gguf.MODEL_ARCH.COMMAND_R

gguf-py/gguf/constants.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ class SSM:
167167
GROUP_COUNT = "{arch}.ssm.group_count"
168168
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
169169

170+
class HybridAttention:
171+
ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
172+
170173
class WKV:
171174
HEAD_SIZE = "{arch}.wkv.head_size"
172175

@@ -317,6 +320,7 @@ class MODEL_ARCH(IntEnum):
317320
ARWKV7 = auto()
318321
MAMBA = auto()
319322
MAMBA2 = auto()
323+
BAMBA = auto()
320324
XVERSE = auto()
321325
COMMAND_R = auto()
322326
COHERE2 = auto()
@@ -598,6 +602,7 @@ class MODEL_TENSOR(IntEnum):
598602
MODEL_ARCH.ARWKV7: "arwkv7",
599603
MODEL_ARCH.MAMBA: "mamba",
600604
MODEL_ARCH.MAMBA2: "mamba2",
605+
MODEL_ARCH.BAMBA: "bamba",
601606
MODEL_ARCH.XVERSE: "xverse",
602607
MODEL_ARCH.COMMAND_R: "command-r",
603608
MODEL_ARCH.COHERE2: "cohere2",
@@ -1629,6 +1634,31 @@ class MODEL_TENSOR(IntEnum):
16291634
MODEL_TENSOR.SSM_NORM,
16301635
MODEL_TENSOR.SSM_OUT,
16311636
],
1637+
MODEL_ARCH.BAMBA: [
1638+
MODEL_TENSOR.TOKEN_EMBD,
1639+
MODEL_TENSOR.OUTPUT_NORM,
1640+
MODEL_TENSOR.OUTPUT,
1641+
MODEL_TENSOR.ATTN_NORM,
1642+
MODEL_TENSOR.SSM_IN,
1643+
MODEL_TENSOR.SSM_CONV1D,
1644+
MODEL_TENSOR.SSM_DT,
1645+
MODEL_TENSOR.SSM_A,
1646+
MODEL_TENSOR.SSM_D,
1647+
MODEL_TENSOR.SSM_NORM,
1648+
MODEL_TENSOR.SSM_OUT,
1649+
MODEL_TENSOR.ATTN_Q,
1650+
MODEL_TENSOR.ATTN_K,
1651+
MODEL_TENSOR.ATTN_V,
1652+
MODEL_TENSOR.ATTN_OUT,
1653+
MODEL_TENSOR.FFN_NORM,
1654+
MODEL_TENSOR.FFN_GATE,
1655+
MODEL_TENSOR.FFN_DOWN,
1656+
MODEL_TENSOR.FFN_UP,
1657+
MODEL_TENSOR.FFN_GATE_INP,
1658+
MODEL_TENSOR.FFN_GATE_EXP,
1659+
MODEL_TENSOR.FFN_DOWN_EXP,
1660+
MODEL_TENSOR.FFN_UP_EXP,
1661+
],
16321662
MODEL_ARCH.XVERSE: [
16331663
MODEL_TENSOR.TOKEN_EMBD,
16341664
MODEL_TENSOR.OUTPUT_NORM,

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,9 @@ def add_ssm_group_count(self, value: int) -> None:
848848
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
849849
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
850850

851+
def add_attn_layer_indices(self, values: list[int]) -> None:
852+
self.add_array(Keys.HybridAttention.ATTN_LAYER_INDICES.format(arch=self.arch), values)
853+
851854
def add_tokenizer_model(self, model: str) -> None:
852855
self.add_string(Keys.Tokenizer.MODEL, model)
853856

gguf-py/gguf/tensor_mapping.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TensorNameMap:
1313
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
1414
"transformer.word_embeddings", # falcon
1515
"word_embeddings", # bloom
16-
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
16+
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 bamba
1717
"tok_embeddings", # llama-pth
1818
"embeddings.word_embeddings", # bert nomic-bert
1919
"language_model.embedding.word_embeddings", # persimmon
@@ -117,7 +117,7 @@ class TensorNameMap:
117117
"transformer.h.{bid}.input_layernorm", # falcon7b
118118
"h.{bid}.input_layernorm", # bloom
119119
"transformer.h.{bid}.ln_mlp", # falcon40b
120-
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
120+
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe bamba
121121
"layers.{bid}.attention_norm", # llama-pth
122122
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
123123
"model.layers.{bid}.ln1", # yi
@@ -268,7 +268,8 @@ class TensorNameMap:
268268
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
269269
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
270270
"transformer.layers.{bid}.ffn_norm", # openelm
271-
"model.layers.{bid}.post_attention_layernorm", # llama4
271+
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4
272+
"model.layers.{bid}.pre_ff_layernorm", # bamba
272273
),
273274

274275
# Post feed-forward norm
@@ -329,7 +330,8 @@ class TensorNameMap:
329330
"model.layers.{bid}.residual_mlp.w3", # arctic
330331
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
331332
"transformer.h.{bid}.mlp.c_fc_1", # exaone
332-
"model.layers.{bid}.feed_forward.up_proj", # llama4
333+
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
334+
"model.layers.{bid}.feed_forward.up_proj", # bamba
333335
),
334336

335337
MODEL_TENSOR.FFN_UP_EXP: (
@@ -366,7 +368,8 @@ class TensorNameMap:
366368
"transformer.h.{bid}.mlp.linear_1", # refact
367369
"model.layers.{bid}.residual_mlp.w1", # arctic
368370
"transformer.h.{bid}.mlp.c_fc_0", # exaone
369-
"model.layers.{bid}.feed_forward.gate_proj", # llama4
371+
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
372+
"model.layers.{bid}.feed_forward.gate_proj", # bamba
370373
),
371374

372375
MODEL_TENSOR.FFN_GATE_EXP: (
@@ -410,7 +413,8 @@ class TensorNameMap:
410413
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
411414
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
412415
"model.layers.h.{bid}.mlp.c_proj", # exaone
413-
"model.layers.{bid}.feed_forward.down_proj", # llama4
416+
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
417+
"model.layers.{bid}.feed_forward.down_proj", # bamba
414418
),
415419

416420
MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -464,11 +468,13 @@ class TensorNameMap:
464468
MODEL_TENSOR.SSM_IN: (
465469
"model.layers.{bid}.in_proj",
466470
"backbone.layers.{bid}.mixer.in_proj",
471+
"model.layers.{bid}.mamba.in_proj", # bamba
467472
),
468473

469474
MODEL_TENSOR.SSM_CONV1D: (
470475
"model.layers.{bid}.conv1d",
471476
"backbone.layers.{bid}.mixer.conv1d",
477+
"model.layers.{bid}.mamba.conv1d", # bamba
472478
),
473479

474480
MODEL_TENSOR.SSM_X: (
@@ -479,25 +485,30 @@ class TensorNameMap:
479485
MODEL_TENSOR.SSM_DT: (
480486
"model.layers.{bid}.dt_proj",
481487
"backbone.layers.{bid}.mixer.dt_proj",
488+
"model.layers.{bid}.mamba.dt_proj", # bamba
482489
),
483490

484491
MODEL_TENSOR.SSM_A: (
485492
"model.layers.{bid}.A_log",
486493
"backbone.layers.{bid}.mixer.A_log",
494+
"model.layers.{bid}.mamba.A_log", # bamba
487495
),
488496

489497
MODEL_TENSOR.SSM_D: (
490498
"model.layers.{bid}.D",
491499
"backbone.layers.{bid}.mixer.D",
500+
"model.layers.{bid}.mamba.D", # bamba
492501
),
493502

494503
MODEL_TENSOR.SSM_NORM: (
495504
"backbone.layers.{bid}.mixer.norm", # mamba2
505+
"model.layers.{bid}.mamba.norm", # bamba
496506
),
497507

498508
MODEL_TENSOR.SSM_OUT: (
499509
"model.layers.{bid}.out_proj",
500510
"backbone.layers.{bid}.mixer.out_proj",
511+
"model.layers.{bid}.mamba.out_proj", # bamba
501512
),
502513

503514
MODEL_TENSOR.TIME_MIX_W0: (

0 commit comments

Comments
 (0)