
Commit e5102d8

Merge branch 'GraniteMoEShared' into GraniteFour
* GraniteMoEShared:
  fix: Fix the input to the shared experts
  fix: Cleaner (maybe more correct?) splitting for gate/up
  feat: First WIP cut at model arch in cpp
  fix: Split MoE fused tensors for shared experts in conversion
  feat: hparam and arch plumbing for granitemoeshared
  feat: Add GGUF conversion for granitemoeshared
  llama-model : support Qwen2 embedding models and pooling_mode_lasttoken (ggml-org#13245)
  convert : use correct context length for nomic-embed-text-v2 (ggml-org#13216)
2 parents 59928ec + 97de56d commit e5102d8

File tree

6 files changed, +372 −237 lines changed

convert_hf_to_gguf.py

Lines changed: 83 additions & 31 deletions
@@ -455,8 +455,12 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type
 
 
 class TextModel(ModelBase):
+    model_type = ModelType.TEXT
+    hf_arch: str
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.hf_arch = get_model_architecture(self.hparams, self.model_type)
 
         if "text_config" in self.hparams:
             # move the text_config to the root level
@@ -506,7 +510,7 @@ def prepare_metadata(self, vocab_only: bool):
     def set_gguf_parameters(self):
         self.gguf_writer.add_block_count(self.block_count)
 
-        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
+        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions"], optional=True)) is not None:
             self.gguf_writer.add_context_length(n_ctx)
             logger.info(f"gguf: context length = {n_ctx}")
 
@@ -1075,10 +1079,36 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab
         if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_EOS)) is not None:
             self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
 
+    def _try_set_pooling_type(self) -> None:
+        # get pooling path
+        pooling_path = None
+        module_path = self.dir_model / "modules.json"
+        if module_path.is_file():
+            with open(module_path, encoding="utf-8") as f:
+                modules = json.load(f)
+            for mod in modules:
+                if mod["type"] == "sentence_transformers.models.Pooling":
+                    pooling_path = mod["path"]
+                    break
+
+        # get pooling type
+        if pooling_path is not None:
+            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
+                pooling = json.load(f)
+            if pooling["pooling_mode_mean_tokens"]:
+                pooling_type = gguf.PoolingType.MEAN
+            elif pooling["pooling_mode_cls_token"]:
+                pooling_type = gguf.PoolingType.CLS
+            elif pooling["pooling_mode_lasttoken"]:
+                pooling_type = gguf.PoolingType.LAST
+            else:
+                raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
+            self.gguf_writer.add_pooling_type(pooling_type)
+
 
 class VisionModel(ModelBase):
+    model_type = ModelType.VISION
     model_arch = gguf.MODEL_ARCH.CLIP_VISION
-    n_text_embd = 0
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
@@ -2542,7 +2572,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
 
 
-@ModelBase.register("Qwen2ForCausalLM")
+@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM")
 class Qwen2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2
 
@@ -2554,12 +2584,18 @@ def set_vocab(self):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
+        self._try_set_pooling_type()
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
             if self.hparams["rope_scaling"].get("type") == "yarn":
                 self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
                 self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
                 self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
 
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if self.hf_arch == "Qwen2Model":
+            name = f"model.{name}"  # map to Qwen2ForCausalLM tensors
+        yield from super().modify_tensors(data_torch, name, bid)
+
 
 @ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
 class Qwen2VLModel(TextModel):
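
For context on the new modify_tensors override: checkpoints registered under the bare Qwen2Model architecture (embedding-style exports) store their tensors without the "model." prefix used by Qwen2ForCausalLM, so the override re-prefixes names before the usual mapping runs. A minimal sketch of that renaming follows; the tensor names are illustrative assumptions, not taken from this diff.

# Illustrative Qwen2Model-style tensor names and the Qwen2ForCausalLM-style
# names the override forwards to the normal tensor mapping (examples are
# assumptions, not from this commit).
qwen2_model_names = [
    "embed_tokens.weight",
    "layers.0.self_attn.q_proj.weight",
    "norm.weight",
]
for name in qwen2_model_names:
    remapped = f"model.{name}"  # e.g. "model.layers.0.self_attn.q_proj.weight"
    print(name, "->", remapped)
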
@@ -3396,29 +3432,7 @@ def __init__(self, *args, **kwargs):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_causal_attention(False)
-
-        # get pooling path
-        pooling_path = None
-        module_path = self.dir_model / "modules.json"
-        if module_path.is_file():
-            with open(module_path, encoding="utf-8") as f:
-                modules = json.load(f)
-            for mod in modules:
-                if mod["type"] == "sentence_transformers.models.Pooling":
-                    pooling_path = mod["path"]
-                    break
-
-        # get pooling type
-        if pooling_path is not None:
-            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
-                pooling = json.load(f)
-            if pooling["pooling_mode_mean_tokens"]:
-                pooling_type = gguf.PoolingType.MEAN
-            elif pooling["pooling_mode_cls_token"]:
-                pooling_type = gguf.PoolingType.CLS
-            else:
-                raise NotImplementedError("Only MEAN and CLS pooling types supported")
-            self.gguf_writer.add_pooling_type(pooling_type)
+        self._try_set_pooling_type()
 
     def set_vocab(self):
         tokens, toktypes, tokpre = self.get_vocab_base()
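
The inline pooling detection removed here now lives in TextModel._try_set_pooling_type() (added further up in this file), so models such as Qwen2Model can reuse it. As a reference for what that helper reads, here is a sketch of a typical sentence-transformers export, shown as Python literals; the exact values are an assumption based on a common last-token-pooling layout, not taken from this commit.

# <dir_model>/modules.json -- lists the sentence-transformers modules; the helper
# picks the entry whose "type" is sentence_transformers.models.Pooling and uses
# its "path" to locate the pooling config (illustrative contents).
modules_json = [
    {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
]

# <dir_model>/1_Pooling/config.json -- boolean flags; exactly one mode is expected
# to be set. With these illustrative values the helper would write
# gguf.PoolingType.LAST into the GGUF metadata.
pooling_config_json = {
    "pooling_mode_cls_token": False,
    "pooling_mode_mean_tokens": False,
    "pooling_mode_lasttoken": True,
}
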
@@ -3627,8 +3641,13 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         if self._tokenizer_is_xlmroberta:
             self._xlmroberta_tokenizer_init()
 
-        # the HF config claims n_ctx=8192, but it uses RoPE scaling
-        self.hparams["n_ctx"] = 2048
+        npos, mtp = self.hparams["n_positions"], self.hparams.get("max_trained_positions", 2048)
+        if npos == 8192 and mtp == 2048:
+            self.hparams["n_positions"] = 2048  # nomic-embed-text v1 and v1.5 are trained for 2048 tokens.
+        elif npos == 2048 and mtp == 2048:
+            self.hparams["n_positions"] = 512  # nomic-embed-text-v2-moe is trained for 512 tokens.
+        else:
+            raise ValueError(f"unrecognized parameters: n_positions={npos}, max_trained_positions={mtp}")
 
         assert self.hparams["activation_function"] == "gelu" if self.is_moe else "swiglu"
 
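
Restated as a lookup, the branch above recognizes exactly two (n_positions, max_trained_positions) combinations and rejects anything else. The helper name below is hypothetical; the values mirror the comments in this hunk.

# (n_positions, max_trained_positions) from the HF config -> context length used for conversion
KNOWN_NOMIC_BERT_CONTEXTS = {
    (8192, 2048): 2048,  # nomic-embed-text v1 / v1.5: trained for 2048 tokens
    (2048, 2048): 512,   # nomic-embed-text-v2-moe: trained for 512 tokens
}

def resolve_n_positions(npos: int, mtp: int) -> int:
    """Hypothetical helper mirroring the if/elif/else above."""
    try:
        return KNOWN_NOMIC_BERT_CONTEXTS[(npos, mtp)]
    except KeyError:
        raise ValueError(f"unrecognized parameters: n_positions={npos}, max_trained_positions={mtp}") from None
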

@@ -5727,6 +5746,39 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("GraniteMoeSharedForCausalLM")
+class GraniteMoeSharedModel(GraniteMoeModel):
+    """Conversion for IBM's GraniteMoeSharedForCausalLM"""
+    model_arch = gguf.MODEL_ARCH.GRANITE_MOE_SHARED
+
+    def set_gguf_parameters(self):
+        """GraniteMoeShared uses GraniteMoe parameters plus the following:
+        - shared_intermediate_size
+        """
+        super().set_gguf_parameters()
+        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
+            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
+            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        """In modeling_granitemoeshared, the implementation of parallel experts
+        is used. This essentially merges w1 and w3 into a single tensor with 2x
+        the hidden size that is then split during forward. To keep compatibility
+        with existing shared expert support, we pull them apart here.
+        """
+
+        if name.endswith("shared_mlp.input_linear.weight"):
+            ffn_dim = self.hparams["shared_intermediate_size"]
+            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
+            gate, up = data_torch.split(ffn_dim, dim=-2)
+            return [
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
+            ]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("BailingMoeForCausalLM")
 class BailingMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BAILINGMOE
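
A standalone sketch of the gate/up split that GraniteMoeSharedModel.modify_tensors() performs on the fused shared_mlp.input_linear.weight tensor; the shapes below are toy values chosen for illustration, not real GraniteMoeShared dimensions.

import torch

# Toy stand-in for "shared_mlp.input_linear.weight": gate (w1) and up (w3) are
# fused along dim -2, giving a [2 * shared_intermediate_size, hidden_size] tensor.
shared_intermediate_size = 4  # toy value
hidden_size = 3               # toy value
fused = torch.arange(2 * shared_intermediate_size * hidden_size, dtype=torch.float32)
fused = fused.reshape(2 * shared_intermediate_size, hidden_size)

# Same split as in modify_tensors(): the first ffn_dim rows become the shared
# gate projection (FFN_GATE_SHEXP), the remaining rows the shared up projection
# (FFN_UP_SHEXP).
gate, up = fused.split(shared_intermediate_size, dim=-2)
assert gate.shape == (shared_intermediate_size, hidden_size)
assert up.shape == (shared_intermediate_size, hidden_size)
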
@@ -6042,8 +6094,7 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n
 
 
-def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
-    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
     arch = hparams["architectures"][0]
@@ -6114,7 +6165,8 @@ def main() -> None:
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
-        model_architecture = get_model_architecture(dir_model, model_type)
+        hparams = ModelBase.load_hparams(dir_model)
+        model_architecture = get_model_architecture(hparams, model_type)
         logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
