
Commit 6875697

refactor: Collapse Bamba and GraniteMoeHybrid into GraniteHybrid
The only key difference is the use of rope, which is now set via rope_finetuned in the hparams.

Branch: GraniteFour
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent fe34d0e commit 6875697
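
The rope distinction the commit message mentions is the only behavioral difference left between the two architectures: the merged converter infers it from the checkpoint's `architectures` list and records it through `gguf_writer.add_rope_scaling_finetuned` (see the `set_gguf_parameters` hunk below). Below is a minimal sketch of that check, assuming a HuggingFace-style `config.json`; the function name is illustrative, not part of the converter.

```python
# Minimal sketch (illustrative, not the converter's actual entry point):
# decide whether a checkpoint should have rope enabled, using the same
# "architectures" check as the new GraniteHybridModel.set_gguf_parameters.
import json

def checkpoint_uses_rope(config_path: str) -> bool:
    with open(config_path) as f:
        hparams = json.load(f)
    # Bamba checkpoints use rope; GraniteMoeHybrid checkpoints do not.
    return "BambaForCausalLM" in hparams.get("architectures", [])
```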

File tree

6 files changed: +374 -455 lines changed


convert_hf_to_gguf.py

Lines changed: 97 additions & 114 deletions
@@ -4971,112 +4971,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield (new_name, data_torch)


-@ModelBase.register("BambaForCausalLM")
-class BambaModel(Mamba2Model):
-    """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
-    model_arch = gguf.MODEL_ARCH.BAMBA
-    undo_permute = True
-
-    def __init__(self, *args, **kwargs):
-
-        # Hybrid mamba models use a prefix for the mamba-specific params.
-        # TODO: Extend this if the prefix(es) need to be configurable
-        self.hparam_prefixes = ["mamba"]
-
-        super().__init__(*args, **kwargs)
-
-        # Use Llama conversion for attention
-        self._transformer_model_class: type[TextModel] = LlamaModel
-
-        # Lists of which layers use ssm vs attention
-        self._attn_layers = self.get_attn_layres()
-        self._ssm_layers = [
-            i for i in range(self.block_count)
-            if i not in self._attn_layers
-        ]
-
-        # n_group and d_inner are used during reshape_tensors for mamaba2
-        self.d_model = self.find_hparam(["hidden_size", "d_model"])
-        self.n_group = self.find_hparam(["n_groups"])
-        self.d_inner = self.find_hparam(["expand"]) * self.d_model
-
-    def get_attn_layres(self) -> list[int]:
-        attn_layers = self.hparams.get("attn_layer_indices", [])
-        if not attn_layers:
-            attn_period = self.hparams.get("attn_layer_period")
-            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
-            attn_offset = self.hparams.get("attn_layer_offset")
-            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
-            attn_layers = [
-                i for i in range(self.block_count)
-                if i % attn_period == attn_offset
-            ]
-        return attn_layers
-
-    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
-        prefixed = []
-        for pfx in self.hparam_prefixes:
-            prefixed.extend(
-                "_".join([pfx, k])
-                for k in keys
-            )
-        keys = list(keys) + prefixed
-        return super().find_hparam(keys, *args, **kwargs)
-
-    def set_gguf_parameters(self):
-
-        ## General Params ##
-        self.gguf_writer.add_embedding_length(self.d_model)
-        self.gguf_writer.add_block_count(self.block_count)
-        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
-        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
-        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
-
-        ## Mamba mixer params ##
-        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
-        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
-        self.gguf_writer.add_ssm_group_count(self.n_group)
-        self.gguf_writer.add_ssm_inner_size(self.d_inner)
-        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
-        #   in llama.cpp
-        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
-
-        ## Attention params ##
-        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
-        if rope_dim := self.hparams.get("attn_rotary_emb"):
-            self.gguf_writer.add_rope_dimension_count(rope_dim)
-        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
-        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
-
-        ## Feed Forward Params ##
-        self.gguf_writer.add_layer_norm_rms_eps(
-            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
-        )
-
-        ## Validation ##
-        d_head = self.find_hparam(["d_head"], optional=True) or 64
-        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
-        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
-
-    def modify_tensors(
-        self, data_torch: Tensor, name: str, bid: int | None
-    ) -> Iterable[tuple[str, Tensor]]:
-
-        # Determine whether this is a mamaba layer or an attention layer
-        if bid in self._ssm_layers:
-            for mamba_new_name, data_torch in super().modify_tensors(
-                data_torch, name, bid
-            ):
-                yield mamba_new_name, data_torch
-        elif bid in self._attn_layers:
-            for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
-                self, data_torch, name, bid
-            ):
-                yield llama_new_name, data_torch
-        else:
-            yield self.map_tensor_name(name), data_torch
-
-
 @ModelBase.register("JambaForCausalLM")
 class JambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.JAMBA
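
The removed BambaModel (and its surviving counterpart further down) resolves Mamba-specific hyperparameters through a prefixed lookup: every requested key is also tried with a `mamba_` prefix before deferring to the parent class's `find_hparam`. A standalone sketch of just that key expansion follows; the helper name and example keys are hypothetical.

```python
# Standalone sketch of the prefixed-key expansion performed by find_hparam
# in the hybrid converter classes; helper name and example keys are made up.
def expand_hparam_keys(keys: list[str], prefixes: tuple[str, ...] = ("mamba",)) -> list[str]:
    prefixed = ["_".join([pfx, k]) for pfx in prefixes for k in keys]
    return list(keys) + prefixed

print(expand_hparam_keys(["n_heads", "d_state"]))
# ['n_heads', 'd_state', 'mamba_n_heads', 'mamba_d_state']
```

This is why a config that stores `mamba_n_heads` still satisfies a lookup for `n_heads`.
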
@@ -6579,19 +6473,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return super().modify_tensors(data_torch, name, bid)


-@ModelBase.register("GraniteMoeHybridForCausalLM")
-class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
-    """GraniteMoeHybrid is a hybrid SSM + MoE Attention model that uses Mamba2
-    SSM layers"""
-    model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID
+@ModelBase.register("GraniteMoeHybridForCausalLM", "BambaForCausalLM")
+class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
+    """GraniteHybrid is a hybrid SSM + Attention model that uses Mamba2 SSM
+    layers and optionally uses MoE w/ a shared expert"""
+    model_arch = gguf.MODEL_ARCH.GRANITE_HYBRID
+    undo_permute = True
+
+    def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
+        super().__init__(*args, **kwargs)
+
+        # Use Granite conversion for attention
+        self._transformer_model_class: type[TextModel] = GraniteModel
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.get_attn_layres()
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+        # n_group and d_inner are used during reshape_tensors for mamaba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model

     def get_attn_layres(self):
+        # Explicit list of layer type names
         if layer_types := self.hparams.get("layer_types"):
             return [
                 i for i, typ in enumerate(layer_types)
                 if typ == "attention"
             ]
-        return super().get_attn_layres()
+
+        # Layer types indicated by index or period
+        attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        return attn_layers
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)

     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
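
The `get_attn_layres` in this hunk accepts three ways of marking attention layers: an explicit `layer_types` list, explicit `attn_layer_indices`, or an `attn_layer_period`/`attn_layer_offset` pair. A worked example of the period/offset fallback follows, with hypothetical values (period 9, offset 4, 36 blocks).

```python
# Worked example of the period/offset fallback in get_attn_layres above;
# the hparam values (period=9, offset=4, 36 blocks) are hypothetical.
block_count = 36
attn_period, attn_offset = 9, 4

attn_layers = [i for i in range(block_count) if i % attn_period == attn_offset]
ssm_layers = [i for i in range(block_count) if i not in attn_layers]

print(attn_layers)      # [4, 13, 22, 31] -> attention (Granite-style) layers
print(len(ssm_layers))  # 32 -> remaining layers use the Mamba2 SSM mixer
```
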
@@ -6601,11 +6542,53 @@ def modify_tensors(
             or "shared_mlp" in name
         ):
             return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
-        return super().modify_tensors(data_torch, name, bid)
+
+        # Determine whether this is a mamaba layer or an attention layer
+        if bid in self._ssm_layers:
+            return super().modify_tensors(data_torch, name, bid)
+        elif bid in self._attn_layers:
+            return self._transformer_model_class.modify_tensors(self, data_torch, name, bid)
+        return [(self.map_tensor_name(name), data_torch)]

     def set_gguf_parameters(self):
         GraniteMoeModel.set_gguf_parameters(self)
-        BambaModel.set_gguf_parameters(self)
+
+        ## General Params ##
+        self.gguf_writer.add_embedding_length(self.d_model)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        #   in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+
+        ## Attention params ##
+        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
+        if rope_dim := self.hparams.get("attn_rotary_emb"):
+            self.gguf_writer.add_rope_dimension_count(rope_dim)
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
+
+        ## Feed Forward Params ##
+        self.gguf_writer.add_layer_norm_rms_eps(
+            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+        )
+
+        ## If Bamba, use rope, otherwise don't
+        use_rope = "BambaForCausalLM" in self.hparams["architectures"]
+        self.gguf_writer.add_rope_scaling_finetuned(use_rope)
+
+        ## Validation ##
+        d_head = self.find_hparam(["d_head"], optional=True) or 64
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"

     def set_vocab(self):
         self.hparams["pad_vocab_size_multiple"] = 8
