
Commit c82b240

feat: Add conversion for Bamba models
This is borrowed and adapted from the original implementation ggml-org#10810

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent ec37cd5 commit c82b240

File tree: 4 files changed (+154 −6 lines changed)

convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py

convert_hf_to_gguf.py

Lines changed: 108 additions & 4 deletions
@@ -4302,6 +4302,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
 class Mamba2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA2
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
+        self.n_group = self.hparams.get("n_groups", 1)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 16
@@ -4371,10 +4377,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
             # (D is also unsqueezed, but for more straightforward broadcast internally)
             data_torch = data_torch.reshape((*data_torch.shape, 1))
         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
-            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
-            n_group = self.hparams.get("n_groups", 1)
-            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+            data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))
 
         if name.endswith(".A_log"):
             logger.debug("A_log --> A ==> " + new_name)
@@ -4383,6 +4386,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         yield (new_name, data_torch)
 
 
+@ModelBase.register("BambaForCausalLM")
+class BambaModel(Mamba2Model):
+    """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
+    model_arch = gguf.MODEL_ARCH.BAMBA
+    undo_permute = True
+
+    def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
+        super().__init__(*args, **kwargs)
+
+        # Use Llama conversion for attention
+        self._transformer_model_class: type[TextModel] = LlamaModel
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not self._attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            self._attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
+    def set_gguf_parameters(self):
+
+        ## General Params ##
+        self.gguf_writer.add_embedding_length(self.d_model)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        # in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+
+        ## Attention params ##
+        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
+        self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
+
+        ## Feed Forward Params ##
+        self.gguf_writer.add_layer_norm_rms_eps(
+            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+        )
+
+        ## Validation ##
+        d_head = self.find_hparam(["d_head"], optional=True) or 64
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+
+        # Determine whether this is a mamba layer or an attention layer
+        if bid in self._ssm_layers:
+            for mamba_new_name, data_torch in super().modify_tensors(
+                data_torch, name, bid
+            ):
+                yield mamba_new_name, data_torch
+        elif bid in self._attn_layers:
+            for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
+                self, data_torch, name, bid
+            ):
+                yield llama_new_name, data_torch
+        else:
+            yield self.map_tensor_name(name), data_torch
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
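
For illustration, the layer-split logic in BambaModel.__init__ can be exercised on its own. The sketch below uses made-up hparam values (attn_layer_period=8, attn_layer_offset=4 are hypothetical, not taken from any real Bamba config):

# Standalone sketch of the attn/ssm layer split performed in BambaModel.__init__.
# The hparams below are hypothetical and only illustrate the arithmetic.
hparams = {
    "num_hidden_layers": 32,
    "attn_layer_period": 8,   # one attention block every 8 layers
    "attn_layer_offset": 4,   # starting at layer index 4
}

block_count = hparams["num_hidden_layers"]
attn_layers = hparams.get("attn_layer_indices", []) or [
    i for i in range(block_count)
    if i % hparams["attn_layer_period"] == hparams["attn_layer_offset"]
]
ssm_layers = [i for i in range(block_count) if i not in attn_layers]

print(attn_layers)  # [4, 12, 20, 28] -> routed through LlamaModel.modify_tensors
print(ssm_layers)   # remaining 28 layers -> routed through Mamba2Model.modify_tensors

In modify_tensors, each tensor's block id (bid) is checked against these two lists to decide which parent conversion path handles it.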

gguf-py/gguf/constants.py

Lines changed: 30 additions & 0 deletions
@@ -167,6 +167,9 @@ class SSM:
         GROUP_COUNT = "{arch}.ssm.group_count"
         DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
 
+    class HybridAttention:
+        ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
+
     class WKV:
         HEAD_SIZE = "{arch}.wkv.head_size"
 
@@ -300,6 +303,7 @@ class MODEL_ARCH(IntEnum):
     ARWKV7 = auto()
     MAMBA = auto()
     MAMBA2 = auto()
+    BAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()
     COHERE2 = auto()
@@ -560,6 +564,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.ARWKV7: "arwkv7",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.MAMBA2: "mamba2",
+    MODEL_ARCH.BAMBA: "bamba",
    MODEL_ARCH.XVERSE: "xverse",
    MODEL_ARCH.COMMAND_R: "command-r",
    MODEL_ARCH.COHERE2: "cohere2",
@@ -1548,6 +1553,31 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.BAMBA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.XVERSE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
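
Once these constants are in place, the new arch and key names can be spot-checked from the gguf package. A minimal sketch, assuming this commit is applied (expected outputs shown as comments):

import gguf

print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.BAMBA])                       # "bamba"
print(gguf.Keys.HybridAttention.ATTN_LAYER_INDICES.format(arch="bamba"))  # "bamba.attention.layer_indices"
print(len(gguf.MODEL_TENSORS[gguf.MODEL_ARCH.BAMBA]))                     # 23 tensor types registered for BAMBA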

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -848,6 +848,9 @@ def add_ssm_group_count(self, value: int) -> None:
     def add_ssm_dt_b_c_rms(self, value: bool) -> None:
         self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
 
+    def add_attn_layer_indices(self, values: list[int]) -> None:
+        self.add_array(Keys.HybridAttention.ATTN_LAYER_INDICES.format(arch=self.arch), values)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
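
A rough usage sketch for the new writer method, assuming this commit is applied; the output path and layer indices below are hypothetical, not part of the commit:

import gguf

# The "{arch}" template in Keys.HybridAttention.ATTN_LAYER_INDICES is filled with
# the writer's arch string, so this lands in the header as
# "bamba.attention.layer_indices".
writer = gguf.GGUFWriter("bamba-test.gguf", "bamba")  # hypothetical output file
writer.add_attn_layer_indices([4, 12, 20, 28])        # hypothetical layer schedule
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()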

gguf-py/gguf/tensor_mapping.py

Lines changed: 13 additions & 2 deletions
@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings", # falcon
             "word_embeddings", # bloom
-            "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
+            "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 bamba
             "tok_embeddings", # llama-pth
             "embeddings.word_embeddings", # bert nomic-bert
             "language_model.embedding.word_embeddings", # persimmon
@@ -117,7 +117,7 @@ class TensorNameMap:
             "transformer.h.{bid}.input_layernorm", # falcon7b
             "h.{bid}.input_layernorm", # bloom
             "transformer.h.{bid}.ln_mlp", # falcon40b
-            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
+            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe bamba
             "layers.{bid}.attention_norm", # llama-pth
             "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
             "model.layers.{bid}.ln1", # yi
@@ -269,6 +269,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.post_attention_layernorm", # chatglm
             "transformer.layers.{bid}.ffn_norm", # openelm
             "language_model.model.layers.{bid}.post_attention_layernorm", # llama4
+            "model.layers.{bid}.pre_ff_layernorm", # bamba
         ),
 
         # Post feed-forward norm
@@ -330,6 +331,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
             "transformer.h.{bid}.mlp.c_fc_1", # exaone
             "language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
+            "model.layers.{bid}.feed_forward.up_proj", # bamba
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -367,6 +369,7 @@ class TensorNameMap:
             "model.layers.{bid}.residual_mlp.w1", # arctic
             "transformer.h.{bid}.mlp.c_fc_0", # exaone
             "language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
+            "model.layers.{bid}.feed_forward.gate_proj", # bamba
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -411,6 +414,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
             "model.layers.h.{bid}.mlp.c_proj", # exaone
             "language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
+            "model.layers.{bid}.feed_forward.down_proj", # bamba
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -464,11 +468,13 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_IN: (
             "model.layers.{bid}.in_proj",
             "backbone.layers.{bid}.mixer.in_proj",
+            "model.layers.{bid}.mamba.in_proj", # bamba
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
             "model.layers.{bid}.conv1d",
             "backbone.layers.{bid}.mixer.conv1d",
+            "model.layers.{bid}.mamba.conv1d", # bamba
         ),
 
         MODEL_TENSOR.SSM_X: (
@@ -479,25 +485,30 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_DT: (
             "model.layers.{bid}.dt_proj",
             "backbone.layers.{bid}.mixer.dt_proj",
+            "model.layers.{bid}.mamba.dt_proj", # bamba
         ),
 
         MODEL_TENSOR.SSM_A: (
             "model.layers.{bid}.A_log",
             "backbone.layers.{bid}.mixer.A_log",
+            "model.layers.{bid}.mamba.A_log", # bamba
         ),
 
         MODEL_TENSOR.SSM_D: (
             "model.layers.{bid}.D",
             "backbone.layers.{bid}.mixer.D",
+            "model.layers.{bid}.mamba.D", # bamba
         ),
 
         MODEL_TENSOR.SSM_NORM: (
             "backbone.layers.{bid}.mixer.norm", # mamba2
+            "model.layers.{bid}.mamba.norm", # bamba
         ),
 
         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",
             "backbone.layers.{bid}.mixer.out_proj",
+            "model.layers.{bid}.mamba.out_proj", # bamba
         ),
 
         MODEL_TENSOR.TIME_MIX_W0: (
