diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d802524bba4a0..8ba9077ee9c9e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -104,9 +104,9 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}") - remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id) + remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_model(remote_hf_model_id) self.tensor_names = set(name for name in remote_tensors.keys()) - for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items(): + for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_model(remote_hf_model_id).items(): yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor)) self.get_tensors = get_remote_tensors diff --git a/convert_mistral_to_gguf.py b/convert_mistral_to_gguf.py new file mode 100755 index 0000000000000..39c8d3e3dab50 --- /dev/null +++ b/convert_mistral_to_gguf.py @@ -0,0 +1,1118 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import logging +import argparse +import json +import os +import sys +from enum import IntEnum +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + ContextManager, + Iterable, + Iterator, + Sequence, + Type, + cast, +) + +import numpy as np +import torch + +from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES +from gguf.vocab import MistralTokenizerType, MistralVocab +from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD + +if TYPE_CHECKING: + from torch import Tensor + +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / "gguf-py")) +import gguf + +logger = logging.getLogger("mistral-to-gguf") + + +###### MODEL DEFINITIONS ###### + + +class SentencePieceTokenTypes(IntEnum): + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 + USER_DEFINED = 4 + UNUSED = 5 + BYTE = 6 + + +class ModelType(IntEnum): + TEXT = 1 + MMPROJ = 2 + + +class ModelBase: + dir_model: Path + ftype: gguf.LlamaFileType + fname_out: Path + is_big_endian: bool + endianess: gguf.GGUFEndian + use_temp_file: bool + lazy: bool + hparams: dict[str, Any] + tensor_names: set[str] | None + gguf_writer: gguf.GGUFWriter + model_name: str | None + metadata_override: Path | None + dir_model_card: Path + remote_hf_model_id: str | None + model_arch: MODEL_ARCH + model_type: ModelType + + # subclasses should initialize this! 
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
+    def __init__(
+        self,
+        dir_model: Path,
+        ftype: gguf.LlamaFileType,
+        fname_out: Path,
+        *,
+        is_big_endian: bool = False,
+        use_temp_file: bool = False,
+        eager: bool = False,
+        metadata_override: Path | None = None,
+        model_name: str | None = None,
+        split_max_tensors: int = 0,
+        split_max_size: int = 0,
+        dry_run: bool = False,
+        small_first_shard: bool = False,
+        hparams: dict[str, Any] | None = None,
+        remote_hf_model_id: str | None = None,
+        ctx: int = 0,
+    ):
+        if (
+            type(self) is ModelBase
+            or type(self) is TextModel
+            or type(self) is MmprojModel
+        ):
+            raise TypeError(
+                f"{type(self).__name__!r} should not be directly instantiated"
+            )
+
+        self.ctx = ctx
+        self.dir_model = dir_model
+        self.ftype = ftype
+        self.fname_out = fname_out
+        self.is_big_endian = is_big_endian
+        self.endianess = (
+            gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
+        )
+        self.use_temp_file = use_temp_file
+        self.lazy = not eager or (remote_hf_model_id is not None)
+        self.remote_hf_model_id = remote_hf_model_id
+        self.vocab = MistralVocab(self.dir_model)
+        if remote_hf_model_id is not None:
+
+            def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
+                logger.info(
+                    f"Using remote model with HuggingFace id: {remote_hf_model_id}"
+                )
+                remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_model(
+                    remote_hf_model_id
+                )
+                self.tensor_names = set(name for name in remote_tensors.keys())
+                for (
+                    name,
+                    remote_tensor,
+                ) in gguf.utility.SafetensorRemote.get_list_tensors_model(
+                    remote_hf_model_id
+                ).items():
+                    yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
+
+            self.get_tensors = get_remote_tensors
+
+        self.hparams = (
+            ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
+        )
+        self.tensor_names = None
+        self.metadata_override = metadata_override
+        self.model_name = model_name
+        self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
+
+        # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
+        if self.ftype == gguf.LlamaFileType.GUESSED:
+            _, first_tensor = next(self.get_tensors())
+            if first_tensor.dtype == torch.float16:
+                logger.info(
+                    f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})"
+                )
+                self.ftype = gguf.LlamaFileType.MOSTLY_F16
+            else:
+                logger.info(
+                    f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})"
+                )
+                self.ftype = gguf.LlamaFileType.MOSTLY_BF16
+
+        # Configure GGUF Writer
+        self.gguf_writer = gguf.GGUFWriter(
+            path=None,
+            arch=MODEL_ARCH_NAMES[self.model_arch],
+            endianess=self.endianess,
+            use_temp_file=self.use_temp_file,
+            split_max_tensors=split_max_tensors,
+            split_max_size=split_max_size,
+            dry_run=dry_run,
+            small_first_shard=small_first_shard,
+        )
+
+    @classmethod
+    def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path:
+        stem, suffix = path.stem, path.suffix
+        new_name = f"{prefix}{stem}{suffix}"
+        return path.with_name(new_name)
+
+    def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        key = next((k for k in keys if k in self.hparams), None)
+        if key is not None:
+            return self.hparams[key]
+        if optional:
+            return None
+        raise KeyError(f"could not find any of: {keys}")
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        tensor_names_from_parts: set[str] = set()
+
+        self.tensor_names = tensor_names_from_parts
+        weight_map: dict[str, str] = {}
+
+        logger.info("gguf: loading 'consolidated.safetensors'")
+        ctx:
ContextManager[Any] + from safetensors import safe_open + + ctx = cast( + ContextManager[Any], + safe_open( + self.dir_model / "consolidated.safetensors", + framework="pt", + device="cpu", + ), + ) + + with ctx as model_part: + tensor_names_from_parts.update(model_part.keys()) + + for name in model_part.keys(): + if self.lazy: + data = model_part.get_slice(name) + data = LazyTorchTensor.from_safetensors_slice(data) + else: + data = model_part.get_tensor(name) + yield name, data + + # verify tensor name presence and identify potentially missing files + if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0: + missing = sorted(self.tensor_names.difference(tensor_names_from_parts)) + extra = sorted(tensor_names_from_parts.difference(self.tensor_names)) + missing_files = sorted( + set(weight_map[n] for n in missing if n in weight_map) + ) + if len(extra) == 0 and len(missing_files) > 0: + raise ValueError( + f"Missing or incomplete model files: {missing_files}\n" + f"Missing tensors: {missing}" + ) + else: + raise ValueError( + "Mismatch between weight map and model parts for tensor names:\n" + f"Missing tensors: {missing}\n" + f"Extra tensors: {extra}" + ) + + def format_tensor_name( + self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight" + ) -> str: + if key not in gguf.MODEL_TENSORS[self.model_arch]: + raise ValueError( + f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}" + ) + name: str = gguf.TENSOR_NAMES[key] + if "{bid}" in name: + assert bid is not None + name = name.format(bid=bid) + return name + suffix + + def match_model_tensor_name( + self, + name: str, + key: gguf.MODEL_TENSOR, + bid: int | None, + suffix: str = ".weight", + ) -> bool: + if key not in gguf.MODEL_TENSORS[self.model_arch]: + return False + key_name: str = gguf.TENSOR_NAMES[key] + if "{bid}" in key_name: + if bid is None: + return False + key_name = key_name.format(bid=bid) + else: + if bid is not None: + return False + return name == (key_name + suffix) + + def map_tensor_name( + self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias") + ) -> str: + new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) + if new_name is None: + raise ValueError(f"Can not map tensor {name!r}") + return new_name + + def set_gguf_parameters(self): + raise NotImplementedError( + "set_gguf_parameters() must be implemented in subclasses" + ) + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len( + ".weight," + ) + + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith( + (".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq") + ): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # use the first number-like part of the tensor name as the block id + bid = None + for part in name.split("."): + if part.isdecimal(): + bid = int(part) + break + + for new_name, data_torch in self.modify_tensors(data_torch, name, bid): + # hard coded for pixtral + if name == "vision_language_adapter.w_in.weight": + assert new_name == "mm.23.weight", new_name + new_name = "mm.1.weight" + elif name == "vision_language_adapter.w_out.weight": 
+ assert new_name == "mm.23.weight", new_name + new_name = "mm.2.weight" + + data = data_torch.numpy() + + # if data ends up empty, it means data_torch was a scalar tensor -> restore + if len(data.shape) == 0: + data = data_torch.numpy() + + n_dims = len(data.shape) + data_qtype: gguf.GGMLQuantizationType | bool = False + + # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors + if n_dims <= 1 or new_name.endswith("_norm.weight"): + data_qtype = gguf.GGMLQuantizationType.F32 + + # Conditions should closely match those in llama_model_quantize_internal in llama.cpp + # Some tensor types are always in float32 + if data_qtype is False and ( + any( + self.match_model_tensor_name(new_name, key, bid) + for key in ( + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.POS_EMBD, + gguf.MODEL_TENSOR.TOKEN_TYPES, + gguf.MODEL_TENSOR.V_ENC_EMBD_POS, + ) + ) + or not new_name.endswith(".weight") + ): + data_qtype = gguf.GGMLQuantizationType.F32 + + if data_qtype is False and any( + self.match_model_tensor_name(new_name, key, bid) + for key in ( + gguf.MODEL_TENSOR.TOKEN_EMBD, + gguf.MODEL_TENSOR.OUTPUT, + ) + ): + if self.ftype in ( + gguf.LlamaFileType.MOSTLY_TQ1_0, + gguf.LlamaFileType.MOSTLY_TQ2_0, + ): + # TODO: use Q4_K and Q6_K + data_qtype = gguf.GGMLQuantizationType.F16 + + # No override (data_qtype is False), or wants to be quantized (data_qtype is True) + if isinstance(data_qtype, bool): + if self.ftype == gguf.LlamaFileType.ALL_F32: + data_qtype = gguf.GGMLQuantizationType.F32 + elif self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_qtype = gguf.GGMLQuantizationType.F16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data_qtype = gguf.GGMLQuantizationType.BF16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: + data_qtype = gguf.GGMLQuantizationType.Q8_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ1_0: + data_qtype = gguf.GGMLQuantizationType.TQ1_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0: + data_qtype = gguf.GGMLQuantizationType.TQ2_0 + else: + raise ValueError(f"Unknown file type: {self.ftype.name}") + + try: + data = gguf.quants.quantize(data, data_qtype) + except gguf.QuantError as e: + logger.warning("%s, %s", e, "falling back to F16") + data_qtype = gguf.GGMLQuantizationType.F16 + data = gguf.quants.quantize(data, data_qtype) + + shape = ( + gguf.quant_shape_from_byte_shape(data.shape, data_qtype) + if data.dtype == np.uint8 + else data.shape + ) + + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" + + # n_dims is implicit in the shape + logger.info( + f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}" + ) + + self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype) + + def set_type(self): + self.gguf_writer.add_type(gguf.GGUFType.MODEL) + + def prepare_metadata(self): + total_params, shared_params, expert_params, expert_count = ( + self.gguf_writer.get_total_parameter_count() + ) + + self.metadata = gguf.Metadata.load( + self.metadata_override, self.dir_model_card, self.model_name, total_params + ) + + # If we are using HF model id, set the metadata name to the model id + if self.remote_hf_model_id: + self.metadata.name = self.remote_hf_model_id + + # Fallback to model directory name if metadata name is still missing + if self.metadata.name is None: + self.metadata.name = self.dir_model.name + + # Generate parameter weight class (useful for leader boards) if not yet determined + if 
self.metadata.size_label is None and total_params > 0: + self.metadata.size_label = gguf.size_label( + total_params, shared_params, expert_params, expert_count + ) + + self.set_type() + + logger.info("Set meta model") + self.metadata.set_gguf_meta_model(self.gguf_writer) + + logger.info("Set model parameters") + self.set_gguf_parameters() + + logger.info("Set model quantization version") + self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) + + def write(self): + self.prepare_tensors() + self.prepare_metadata() + self.gguf_writer.write_header_to_file(path=self.fname_out) + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file(progress=True) + self.gguf_writer.close() + + @staticmethod + def load_hparams(dir_model: Path): + with open(dir_model / "params.json", "r", encoding="utf-8") as f: + config = json.load(f) + return config + + +class TextModel(ModelBase): + model_type = ModelType.TEXT + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if "text_config" in self.hparams: + # move the text_config to the root level + self.hparams = {**self.hparams, **self.hparams["text_config"]} + + self.block_count = self.find_hparam( + ["n_layers", "num_hidden_layers", "n_layer", "num_layers"] + ) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + logger.info( + f"Converting tokenizer {self.vocab.tokenizer_type} of size {self.vocab.vocab_size}." + ) + + self.gguf_writer.add_tokenizer_model(self.vocab.gguf_tokenizer_model) + + tokens = [] + scores = [] + toktypes = [] + + for text, score, toktype in self.vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == self.vocab.vocab_size, ( + f"token count ({len(tokens)}) != vocab size ({self.vocab.vocab_size})" + ) + + if self.vocab.tokenizer_type == MistralTokenizerType.tekken: + self.gguf_writer.add_tokenizer_pre("tekken") + self.gguf_writer.add_token_merges( + self.vocab.extract_vocab_merges_from_model() + ) + + logger.info( + f"Setting bos, eos, unk and pad token IDs to {self.vocab.bos_id}, {self.vocab.eos_id}, {self.vocab.unk_id}, {self.vocab.pad_id}." + ) + + self.gguf_writer.add_bos_token_id(self.vocab.bos_id) + self.gguf_writer.add_eos_token_id(self.vocab.eos_id) + self.gguf_writer.add_unk_token_id(self.vocab.unk_id) + self.gguf_writer.add_pad_token_id(self.vocab.pad_id) + + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_vocab_size(self.vocab.vocab_size) + + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(False) + + def set_vocab_none(self): + logger.info("Skipping tokenizer conversion.") + logger.info("Setting tokenizer to 'none'.") + self.gguf_writer.add_tokenizer_model("none") + + logger.info( + f"Setting bos, eos, unk and pad token IDs to {self.vocab.bos_id}, {self.vocab.eos_id}, {self.vocab.unk_id}, {self.vocab.pad_id}." 
+ ) + self.gguf_writer.add_bos_token_id(self.vocab.bos_id) + self.gguf_writer.add_eos_token_id(self.vocab.eos_id) + self.gguf_writer.add_unk_token_id(self.vocab.unk_id) + self.gguf_writer.add_pad_token_id(self.vocab.pad_id) + + logger.info(f"Setting vocab size to {self.vocab.vocab_size}.") + self.gguf_writer.add_vocab_size(self.vocab.vocab_size) + + self.gguf_writer.add_add_bos_token(False) + self.gguf_writer.add_add_eos_token(False) + + def prepare_metadata(self): + super().prepare_metadata() + + total_params = self.gguf_writer.get_total_parameter_count()[0] + # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0' + output_type: str = self.ftype.name.partition("_")[2] + + # Filename Output + if self.fname_out.is_dir(): + # Generate default filename based on model specification and available metadata + fname_default: str = gguf.naming_convention( + self.metadata.name, + self.metadata.basename, + self.metadata.finetune, + self.metadata.version, + self.metadata.size_label, + output_type, + model_type="LoRA" if total_params < 0 else None, + ) + + # Use the default filename + self.fname_out = self.fname_out / f"{fname_default}.gguf" + else: + # Output path is a custom defined templated filename + # Note: `not is_dir()` is used because `.is_file()` will not detect + # file template strings as it doesn't actually exist as a file + + # Process templated file name with the output ftype, useful with the "auto" ftype + self.fname_out = self.fname_out.parent / gguf.fill_templated_filename( + self.fname_out.name, output_type + ) + + logger.info("Set model tokenizer") + self.set_vocab() + + def set_gguf_parameters(self): + self.gguf_writer.add_block_count(self.block_count) + + if self.ctx == 0: + raise ValueError("ctx not passed as argument") + self.gguf_writer.add_context_length(self.ctx) + logger.info(f"gguf: training context length = {self.ctx}") + + if (n_embd := self.find_hparam(["dim"], optional=True)) is not None: + self.gguf_writer.add_embedding_length(n_embd) + logger.info(f"gguf: embedding length = {n_embd}") + + if (n_ff := self.find_hparam(["hidden_dim"], optional=True)) is not None: + self.gguf_writer.add_feed_forward_length(n_ff) + logger.info(f"gguf: feed forward length = {n_ff}") + + if (n_head := self.find_hparam(["n_heads"], optional=True)) is not None: + self.gguf_writer.add_head_count(n_head) + logger.info(f"gguf: head count = {n_head}") + + if (n_head_kv := self.hparams.get("n_kv_heads")) is not None: + self.gguf_writer.add_head_count_kv(n_head_kv) + logger.info(f"gguf: key-value head count = {n_head_kv}") + + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + logger.info(f"gguf: rope theta = {rope_theta}") + + if (f_norm_eps := self.find_hparam(["norm_eps"], optional=True)) is not None: + self.gguf_writer.add_layer_norm_rms_eps(f_norm_eps) + logger.info(f"gguf: layer norm epsilon = {f_norm_eps}") + + if (head_dim := self.hparams.get("head_dim")) is not None: + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) + + self.gguf_writer.add_file_type(self.ftype) + logger.info(f"gguf: file type = {self.ftype}") + + +class MmprojModel(ModelBase): + model_type = ModelType.MMPROJ + model_arch = gguf.MODEL_ARCH.MMPROJ + preprocessor_config: dict[str, Any] + global_config: dict[str, Any] + + n_block_keys = ["num_hidden_layers"] + + has_vision_encoder: bool = True + + hparams_vision: dict[str, Any] + + def __init__(self, *args, **kwargs): + 
super().__init__(*args, **kwargs) + + text_config = { + k: v for k, v in self.hparams.items() if k not in ["vision_encoder"] + } + self.n_embd_text = text_config.get("hidden_dim", 0) + assert self.n_embd_text > 0, "n_embd not found in hparams" + + # move vision config to the top level, while preserving the original hparams in global_config + import copy + + self.global_config = copy.deepcopy(self.hparams) + self.hparams_vision = self.get_vision_config() + + self.block_count = self.hparams_vision.get("num_hidden_layers", 0) + assert self.block_count > 0, "num_hidden_layers not found in vision_config" + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def get_vision_config(self) -> dict[str, Any]: + vision_config = self.global_config.get("vision_encoder") + assert vision_config is not None, "vision_config not found in hparams" + return vision_config + + def set_type(self): + self.gguf_writer.add_type(gguf.GGUFType.MMPROJ) + + def set_gguf_parameters(self): + self.gguf_writer.add_file_type(self.ftype) + + if not self.has_vision_encoder: + raise ValueError("MmprojModel must have a vision encoder") + + def find_vparam(self, keys: Iterable[str], optional: bool = False) -> Any: + assert self.hparams_vision is not None + return self._find_param(self.hparams_vision, keys, optional) + + def _find_param( + self, obj: dict[str, Any], keys: Iterable[str], optional: bool = False + ) -> Any: + key = next((k for k in keys if k in obj), None) + if key is not None: + return obj[key] + if optional: + return None + raise KeyError(f"could not find any of: {keys}") + + +class MistralModel(TextModel): + model_name = "mistral" + model_arch = MODEL_ARCH.MISTRAL + undo_permute = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + rope_scaling = self.hparams.get("rope_scaling") or {} + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" + and "factor" in rope_scaling + ): + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return ( + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) + .swapaxes(1, 2) + .reshape(weights.shape) + ) + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["n_heads"] + n_kv_head = self.hparams.get("n_kv_heads") + is_vision_tensor = any( + name.startswith(prefix) + for prefix in [ + "vision_encoder.", + "vision_language_adapter.", + "patch_merger.", + "pre_mm_projector_norm", + ] + ) + + if is_vision_tensor: + return [] # skip vision tensors + + if self.undo_permute: + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = self.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = self.permute(data_torch, n_head, n_kv_head) + + return [(self.map_tensor_name(name), data_torch)] + + +class PixtralModel(MmprojModel): + model_name = "mistral" + img_break_tok_id = -1 + + def __init__(self, *args, 
**kwargs): + super().__init__(*args, **kwargs) + # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py + self.hparams["layer_norm_eps"] = self.hparams.get("norm_eps", 1e-5) + self.img_break_tok_id = self.hparams_vision.get("image_break_token_id", -1) + assert self.img_break_tok_id >= 0, ( + "image_break_token_id not found in vision_config" + ) + logger.info(f"Image break token id: {self.img_break_tok_id}") + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL) + + self.gguf_writer.add_clip_has_vision_encoder(True) + self.gguf_writer.add_vision_projection_dim(self.n_embd_text) + + # vision config + self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"])) + self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) + self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) + self.gguf_writer.add_vision_feed_forward_length( + self.find_vparam(["intermediate_size"]) + ) + self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys)) + self.gguf_writer.add_vision_head_count( + self.find_vparam(["num_attention_heads"]) + ) + + # preprocessor config + self.gguf_writer.add_vision_image_mean( + self.hparams_vision.get("image_mean", DATASET_MEAN) + ) + self.gguf_writer.add_vision_image_std( + self.hparams_vision.get("image_std", DATASET_STD) + ) + + self.gguf_writer.add_vision_attention_layernorm_eps( + self.find_hparam(["layer_norm_eps"]) + ) + self.gguf_writer.add_rope_freq_base(self.find_vparam(["rope_theta"])) + + self.gguf_writer.add_vision_use_silu(True) + + # spatial_merge_size + if self.hparams_vision["mm_projector_id"] == "patch_merge": + self.gguf_writer.add_vision_spatial_merge_size( + self.find_vparam(["spatial_merge_size"]) + ) + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + del bid # unused + n_head = self.hparams_vision["num_attention_heads"] + n_kv_head = n_head + + if any( + name.startswith(prefix) + for prefix in [ + "vision_encoder.", + "vision_language_adapter.", + "patch_merger.", + "pre_mm_projector_norm", + ] + ): + # process vision tensors + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = MistralModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = MistralModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + if self.img_break_tok_id > 0 and "tok_embeddings.weight" in name: + logger.info(f"Extracting [IMG_BREAK] token embedding from {name}") + # for pixtral model, we need to extract the [IMG_BREAK] token embedding + img_break_embd = data_torch[self.img_break_tok_id] + name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK] + return [(self.map_tensor_name(name), img_break_embd)] + + return [] # skip other tensors + + +# tree of lazy tensors +class LazyTorchTensor(gguf.LazyBase): + _tensor_type = torch.Tensor + # to keep the type-checker happy + dtype: torch.dtype + shape: torch.Size + + # only used when converting a torch.Tensor to a np.ndarray + _dtype_map: dict[torch.dtype, type] = { + torch.float16: np.float16, + torch.float32: np.float32, + } + + # used for safetensors slices + # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 + # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 + 
_dtype_str_map: dict[str, torch.dtype] = { + "F64": torch.float64, + "F32": torch.float32, + "BF16": torch.bfloat16, + "F16": torch.float16, + # "U64": torch.uint64, + "I64": torch.int64, + # "U32": torch.uint32, + "I32": torch.int32, + # "U16": torch.uint16, + "I16": torch.int16, + "U8": torch.uint8, + "I8": torch.int8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + } + + def numpy(self) -> gguf.LazyNumpyTensor: + dtype = self._dtype_map[self.dtype] + return gguf.LazyNumpyTensor( + meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape), + args=(self,), + func=(lambda s: s.numpy()), + ) + + @classmethod + def meta_with_dtype_and_shape( + cls, dtype: torch.dtype, shape: tuple[int, ...] + ) -> Tensor: + return torch.empty(size=shape, dtype=dtype, device="meta") + + @classmethod + def from_safetensors_slice(cls, st_slice: Any) -> Tensor: + dtype = cls._dtype_str_map[st_slice.get_dtype()] + shape: tuple[int, ...] = tuple(st_slice.get_shape()) + lazy = cls( + meta=cls.meta_with_dtype_and_shape(dtype, shape), + args=(st_slice,), + func=lambda s: s[:], + ) + return cast(torch.Tensor, lazy) + + @classmethod + def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): + dtype = cls._dtype_str_map[remote_tensor.dtype] + shape = remote_tensor.shape + meta = cls.meta_with_dtype_and_shape(dtype, shape) + lazy = cls( + meta=meta, + args=(remote_tensor,), + func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape), + ) + return cast(torch.Tensor, lazy) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + del types # unused + + if kwargs is None: + kwargs = {} + + if func is torch.Tensor.numpy: + return args[0].numpy() + + return cls._wrap_fn(func)(*args, **kwargs) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert a huggingface model to a GGML compatible file" + ) + parser.add_argument( + "--outfile", + type=Path, + help="path to write to; default: based on input. 
{ftype} will be replaced by the outtype.", + ) + parser.add_argument( + "--outtype", + type=str, + choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], + default="bf16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + ) + parser.add_argument( + "--bigendian", + action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "model", + type=Path, + help="directory containing model file", + nargs="?", + ) + parser.add_argument( + "--ctx-train", + type=int, + help="Training context size", + required=False, + ) + parser.add_argument( + "--use-temp-file", + action="store_true", + help="use the tempfile library while processing (helpful when running out of memory, process killed)", + ) + parser.add_argument( + "--no-lazy", + action="store_true", + help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)", + ) + parser.add_argument( + "--model-name", + type=str, + default=None, + help="name of the model", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="increase output verbosity", + ) + parser.add_argument( + "--split-max-tensors", + type=int, + default=0, + help="max tensors in each split", + ) + parser.add_argument( + "--split-max-size", + type=str, + default="0", + help="max size per split N(M|G)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="only print out a split plan and exit, without writing any new files", + ) + parser.add_argument( + "--no-tensor-first-split", + action="store_true", + help="do not add tensors to the first split (disabled by default)", + ) + parser.add_argument( + "--metadata", + type=Path, + help="Specify the path for an authorship metadata override file", + ) + parser.add_argument( + "--remote", + action="store_true", + help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'mistralai/Mistral-Small-3.2-24B-Instruct-2506'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.", + ) + parser.add_argument( + "--mmproj", + action="store_true", + help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. 
A prefix 'mmproj-' will be added to the output file name.", + ) + + args = parser.parse_args() + return args + + +def split_str_to_n_bytes(split_str: str) -> int: + if split_str.endswith("K"): + n = int(split_str[:-1]) * 1000 + elif split_str.endswith("M"): + n = int(split_str[:-1]) * 1000 * 1000 + elif split_str.endswith("G"): + n = int(split_str[:-1]) * 1000 * 1000 * 1000 + elif split_str.isnumeric(): + n = int(split_str) + else: + raise ValueError( + f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G" + ) + + if n < 0: + raise ValueError(f"Invalid split size: {split_str}, must be positive") + + return n + + +def main() -> None: + args = parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + dir_model = args.model + + if args.remote: + from huggingface_hub import snapshot_download + + local_dir = snapshot_download( + repo_id=str(dir_model), + allow_patterns=[ + "LICENSE", + "params.json", + "tekken.json", + "*.md", + "tokenizer.model", + ], + ) + dir_model = Path(local_dir) + logger.info(f"Downloaded config and tokenizer to {local_dir}") + + if not dir_model.is_dir(): + logger.error(f"Error: {args.model} is not a directory") + sys.exit(1) + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + "tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0, + "tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0, + "auto": gguf.LlamaFileType.GUESSED, + } + + is_split = args.split_max_tensors > 0 or args.split_max_size != "0" + if args.use_temp_file and is_split: + logger.error("Error: Cannot use temp file when splitting") + sys.exit(1) + + if args.outfile is not None: + fname_out = args.outfile + elif args.remote: + # if remote, use the model ID as the output file name + fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf") + else: + fname_out = dir_model + + logger.info(f"Loading model: {dir_model.name}") + + with torch.inference_mode(): + output_type = ftype_map[args.outtype] + hparams = ModelBase.load_hparams(dir_model) + model_class: Type[ModelBase] + if args.mmproj and hparams.get("vision_encoder") is not None: + model_class = PixtralModel + elif args.mmproj: + raise ValueError( + "Multimodal projector export is only supported for vision models" + ) + else: + model_class = MistralModel + logger.info(f"Model architecture: {model_class.__name__}") + + model_instance = model_class( + dir_model, + output_type, + fname_out, + is_big_endian=args.bigendian, + use_temp_file=args.use_temp_file, + eager=args.no_lazy, + metadata_override=args.metadata, + model_name=args.model_name, + split_max_tensors=args.split_max_tensors, + split_max_size=split_str_to_n_bytes(args.split_max_size), + dry_run=args.dry_run, + small_first_shard=args.no_tensor_first_split, + remote_hf_model_id=str(args.model) if args.remote else None, + ctx=args.ctx_train, + ) + + logger.info("Exporting model...") + model_instance.write() + out_path = ( + f"{model_instance.fname_out.parent}{os.sep}" + if is_split + else model_instance.fname_out + ) + logger.info(f"Model successfully exported to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d8afe7696d243..37981d6b7f2c9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -368,6 +368,7 @@ class MODEL_ARCH(IntEnum): SMOLLM3 = 
auto() LFM2 = auto() DREAM = auto() + MISTRAL = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -685,6 +686,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.LFM2: "lfm2", MODEL_ARCH.DREAM: "dream", + MODEL_ARCH.MISTRAL: "mistral", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -2434,6 +2436,21 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, ], + MODEL_ARCH.MISTRAL: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ] # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 2a675044f9d99..1de0eebc6dad5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1058,6 +1058,8 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", "visual.merger.mlp.{bid}", # qwen2vl + "vision_language_adapter.w_in", # pixtral + "vision_language_adapter.w_out", # pixtral ), MODEL_TENSOR.V_MMPROJ_FC: ( @@ -1083,7 +1085,8 @@ class TensorNameMap: "vision_tower.vision_model.embeddings.patch_embedding", "vpm.embeddings.patch_embedding", "model.vision_model.embeddings.patch_embedding", # SmolVLM - "vision_tower.patch_conv", # pixtral + "vision_tower.patch_conv", # pixtral-hf + "vision_encoder.patch_conv", # pixtral "vision_model.patch_embedding.linear", # llama 4 "visual.patch_embed.proj", # qwen2vl ), @@ -1100,7 +1103,8 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.self_attn.q_proj", "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4 - "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral + "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf + "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral "visual.blocks.{bid}.attn.q", # qwen2vl, generated ), @@ -1113,7 +1117,8 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.self_attn.k_proj", "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4 - "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral + "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf + "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral "visual.blocks.{bid}.attn.k", # qwen2vl, generated ), @@ -1126,7 +1131,8 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.self_attn.v_proj", "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4 - "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral + "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf + "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral "visual.blocks.{bid}.attn.v", # qwen2vl, generated ), @@ -1135,7 +1141,8 @@ class TensorNameMap: "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL "vpm.encoder.layers.{bid}.layer_norm1", "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM - "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral + "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral-hf + "vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral 
"vision_model.model.layers.{bid}.input_layernorm", # llama4 "visual.blocks.{bid}.norm1", # qwen2vl ), @@ -1146,7 +1153,8 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.self_attn.out_proj", "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4 - "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral + "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf + "vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral "visual.blocks.{bid}.attn.proj", # qwen2vl ), @@ -1156,7 +1164,8 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.layer_norm2", "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4 - "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral + "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf + "vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral "visual.blocks.{bid}.norm2", # qwen2vl ), @@ -1164,14 +1173,16 @@ class TensorNameMap: "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", "vpm.encoder.layers.{bid}.mlp.fc1", "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 - "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral + "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral-hf + "vision_encoder.transformer.layers.{bid}.feed_forward.w3", # pixtral "vision_model.model.layers.{bid}.mlp.fc1", # llama4 "visual.blocks.{bid}.mlp.fc1", # qwen2vl "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl ), MODEL_TENSOR.V_ENC_FFN_GATE: ( - "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral + "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral-hf + "vision_encoder.transformer.layers.{bid}.feed_forward.w1", # pixtral "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl ), @@ -1179,7 +1190,8 @@ class TensorNameMap: "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", "vpm.encoder.layers.{bid}.mlp.fc2", "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 - "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral + "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral-hf + "vision_encoder.transformer.layers.{bid}.feed_forward.w2", # pixtral "vision_model.model.layers.{bid}.mlp.fc2", # llama4 "visual.blocks.{bid}.mlp.fc2", # qwen2vl "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl @@ -1195,7 +1207,8 @@ class TensorNameMap: MODEL_TENSOR.V_PRE_NORM: ( "vision_tower.vision_model.pre_layrnorm", - "vision_tower.ln_pre", # pixtral + "vision_tower.ln_pre", # pixtral-hf + "vision_encoder.ln_pre", # pixtral "vision_model.layernorm_pre", # llama4 ), @@ -1212,6 +1225,7 @@ class TensorNameMap: MODEL_TENSOR.V_MM_INP_NORM: ( "multi_modal_projector.norm", + "pre_mm_projector_norm", ), MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( @@ -1267,7 +1281,8 @@ class TensorNameMap: ), MODEL_TENSOR.V_MM_PATCH_MERGER: ( - "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 + "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 - hf + "patch_merger.merging_layer", # mistral ), # audio (mtmd) diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 00adcbc937398..8354bd922c1b7 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -111,7 +111,7 @@ class SafetensorRemote: ALIGNMENT = 8 # bytes @classmethod - def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: + def 
get_list_tensors_model(cls, model_id: str) -> dict[str, RemoteTensor]: """ Get list of tensors from a Hugging Face model repository. @@ -120,9 +120,13 @@ def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: """ # case 1: model has only one single model.safetensor file is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors") + is_single_file_consolidated = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/consolidated.safetensors", user_agent="convert_mistral_to_gguf") if is_single_file: url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors" return cls.get_list_tensors(url) + if is_single_file_consolidated: + url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/consolidated.safetensors" + return cls.get_list_tensors(url) # case 2: model has multiple files index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json" @@ -145,7 +149,11 @@ def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: tensors[key] = val return tensors - raise ValueError(f"Model {model_id} does not have any safetensor files") + raise ValueError( + f"No safetensor file has been found for model {model_id}." + "If the repo has safetensor files, make sure the model is public or you have a " + "valid Hugging Face token set in the environment variable HF_TOKEN." + ) @classmethod def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]: @@ -234,7 +242,7 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: return response.content[slice(size if size > -1 else None)] @classmethod - def check_file_exist(cls, url: str) -> bool: + def check_file_exist(cls, url: str, user_agent="convert_hf_to_gguf") -> bool: """ Check if a file exists at the given URL. Returns True if the file exists, False otherwise. 
@@ -247,7 +255,7 @@ def check_file_exist(cls, url: str) -> bool: raise ValueError(f"Invalid URL: {url}") try: - headers = cls._get_request_headers() + headers = cls._get_request_headers(user_agent=user_agent) headers["Range"] = "bytes=0-0" response = requests.head(url, allow_redirects=True, headers=headers) # Success (2xx) or redirect (3xx) @@ -256,9 +264,9 @@ def check_file_exist(cls, url: str) -> bool: return False @classmethod - def _get_request_headers(cls) -> dict[str, str]: + def _get_request_headers(cls, user_agent="convert_hf_to_gguf") -> dict[str, str]: """Prepare common headers for requests.""" - headers = {"User-Agent": "convert_hf_to_gguf"} + headers = {"User-Agent": user_agent} if os.environ.get("HF_TOKEN"): headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" return headers diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 635fcef35e235..d3caeac36a3d9 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -1,21 +1,46 @@ from __future__ import annotations +from enum import Enum import re import logging import json import os from pathlib import Path -from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable +from typing import ( + Any, + Callable, + Sequence, + Mapping, + Iterable, + Protocol, + ClassVar, + runtime_checkable +) try: from sentencepiece import SentencePieceProcessor except ImportError: SentencePieceProcessor = None +try: + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + from mistral_common.tokens.tokenizers.utils import ( + filter_valid_tokenizer_files, + ) + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) +except ImportError: + _mistral_common_installed = False +else: + _mistral_common_installed = True + import gguf from .gguf_writer import GGUFWriter + logger = logging.getLogger(__name__) @@ -26,7 +51,9 @@ class SpecialVocab: chat_template: str | Sequence[Mapping[str, str]] | None def __init__( - self, path: str | os.PathLike[str], load_merges: bool = False, + self, + path: str | os.PathLike[str], + load_merges: bool = False, special_token_types: Iterable[str] | None = None, n_vocab: int | None = None, ): @@ -39,40 +66,60 @@ def __init__( if special_token_types is not None: self.special_token_types = special_token_types else: - self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask') + self.special_token_types = ( + "bos", + "eos", + "unk", + "sep", + "pad", + "cls", + "mask", + ) self._load(Path(path)) def __repr__(self) -> str: - return ''.format( - len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset", + return "".format( + len(self.merges), + self.special_token_ids or "unset", + self.add_special_token or "unset", ) def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None: if self.merges: if not quiet: - logger.info(f'Adding {len(self.merges)} merge(s).') + logger.info(f"Adding {len(self.merges)} merge(s).") gw.add_token_merges(self.merges) elif self.load_merges: - logger.warning('Adding merges requested but no merges found, output may be non-functional.') + logger.warning( + "Adding merges requested but no merges found, output may be non-functional." 
+ ) for typ, tokid in self.special_token_ids.items(): - id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None) + id_handler: Callable[[int], None] | None = getattr( + gw, f"add_{typ}_token_id", None + ) if id_handler is None: - logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping') + logger.warning( + f"No handler for special token type {typ} with id {tokid} - skipping" + ) continue if not quiet: - logger.info(f'Setting special token type {typ} to {tokid}') + logger.info(f"Setting special token type {typ} to {tokid}") id_handler(tokid) for typ, value in self.add_special_token.items(): - add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None) + add_handler: Callable[[bool], None] | None = getattr( + gw, f"add_add_{typ}_token", None + ) if add_handler is None: - logger.warning(f'No handler for add_{typ}_token with value {value} - skipping') + logger.warning( + f"No handler for add_{typ}_token with value {value} - skipping" + ) continue if not quiet: - logger.info(f'Setting add_{typ}_token to {value}') + logger.info(f"Setting add_{typ}_token to {value}") add_handler(value) if self.chat_template is not None: if not quiet: - logger.info(f'Setting chat_template to {self.chat_template}') + logger.info(f"Setting chat_template to {self.chat_template}") gw.add_chat_template(self.chat_template) def _load(self, path: Path) -> None: @@ -82,12 +129,12 @@ def _load(self, path: Path) -> None: self._try_load_merges_txt(path) def _try_load_merges_txt(self, path: Path) -> bool: - merges_file = path / 'merges.txt' + merges_file = path / "merges.txt" if not merges_file.is_file(): return False - with open(merges_file, 'r', encoding = 'utf-8') as fp: - first_line = next(fp, '').strip() - if not first_line.startswith('#'): + with open(merges_file, "r", encoding="utf-8") as fp: + first_line = next(fp, "").strip() + if not first_line.startswith("#"): fp.seek(0) line_num = 0 else: @@ -100,9 +147,11 @@ def _try_load_merges_txt(self, path: Path) -> bool: continue parts = line.split(None, 3) if len(parts) != 2: - logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring') + logger.warning( + f"{merges_file.name}: Line {line_num}: Entry malformed, ignoring" + ) continue - merges.append(f'{parts[0]} {parts[1]}') + merges.append(f"{parts[0]} {parts[1]}") self.merges = merges return True @@ -110,37 +159,45 @@ def _set_special_token(self, typ: str, tid: Any) -> None: if not isinstance(tid, int): return if tid < 0: - raise ValueError(f'invalid value for special token type {typ}: {tid}') + raise ValueError(f"invalid value for special token type {typ}: {tid}") if self.n_vocab is None or tid < self.n_vocab: if typ in self.special_token_ids: return self.special_token_ids[typ] = tid return - logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping') + logger.warning( + f"Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping" + ) def _try_load_from_tokenizer_json(self, path: Path) -> bool: tokenizer = None tokenizer_file = path / 'tokenizer.json' if tokenizer_file.is_file(): - with open(tokenizer_file, encoding = 'utf-8') as f: + with open(tokenizer_file, encoding="utf-8") as f: tokenizer = json.load(f) if self.load_merges: - merges = tokenizer.get('model', {}).get('merges') + merges = tokenizer.get("model", {}).get("merges") if isinstance(merges, list) and merges: if isinstance(merges[0], str): self.merges = merges - elif 
isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str): + elif ( + isinstance(merges[0], list) + and len(merges[0]) == 2 + and isinstance(merges[0][0], str) + ): # New format since transformers 4.45 to support spaces in merges # ref: https://github.com/ggml-org/llama.cpp/issues/9692 # TODO: internally store as the new format instead of converting to old - if any(' ' in s for pair in merges for s in pair): - logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}') + if any(" " in s for pair in merges for s in pair): + logger.warning( + f"Spaces in merges detected, encoding as {chr(ord(' ') + 256)!r}" + ) self.merges = [ - ' '.join( + " ".join( [ # ensure the spaces are properly encoded - ''.join( - chr(ord(c) + 256) if c == ' ' else c + "".join( + chr(ord(c) + 256) if c == " " else c for c in part ) for part in pair @@ -150,7 +207,7 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool: ] else: raise ValueError("Unknown tokenizer merges format") - added_tokens = tokenizer.get('added_tokens', {}) + added_tokens = tokenizer.get("added_tokens", {}) else: added_tokens = {} tokenizer_config = None @@ -262,16 +319,18 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool: if chat_template is None or isinstance(chat_template, (str, list)): self.chat_template = chat_template else: - logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring') + logger.warning( + f"Bad type for chat_template field in {tokenizer_config_file!r} - ignoring" + ) for typ in self.special_token_types: - add_entry = tokenizer_config.get(f'add_{typ}_token') + add_entry = tokenizer_config.get(f"add_{typ}_token") if isinstance(add_entry, bool): self.add_special_token[typ] = add_entry - entry = tokenizer_config.get(f'{typ}_token') + entry = tokenizer_config.get(f"{typ}_token") if isinstance(entry, str): tc_content = entry elif isinstance(entry, dict): - entry_content = entry.get('content') + entry_content = entry.get("content") if not isinstance(entry_content, str): continue tc_content = entry_content @@ -279,20 +338,24 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool: continue # We only need the first match here. maybe_token_id = next( - (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content), + ( + atok.get("id") + for atok in added_tokens + if atok.get("content") == tc_content + ), None, ) self._set_special_token(typ, maybe_token_id) return True def _try_load_from_config_json(self, path: Path) -> bool: - config_file = path / 'config.json' + config_file = path / "config.json" if not config_file.is_file(): return False - with open(config_file, encoding = 'utf-8') as f: + with open(config_file, encoding="utf-8") as f: config = json.load(f) for typ in self.special_token_types: - self._set_special_token(typ, config.get(f'{typ}_token_id')) + self._set_special_token(typ, config.get(f"{typ}_token_id")) return True @@ -328,54 +391,59 @@ class BpeVocab(Vocab): def __init__(self, base_path: Path): added_tokens: dict[str, int] = {} - if (fname_tokenizer := base_path / 'vocab.json').exists(): + if (fname_tokenizer := base_path / "vocab.json").exists(): # "slow" tokenizer with open(fname_tokenizer, encoding="utf-8") as f: self.vocab = json.load(f) try: # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab. 
- with open(base_path / 'added_tokens.json', encoding="utf-8") as f: + with open(base_path / "added_tokens.json", encoding="utf-8") as f: added_tokens = json.load(f) except FileNotFoundError: pass else: # "fast" tokenizer - fname_tokenizer = base_path / 'tokenizer.json' + fname_tokenizer = base_path / "tokenizer.json" # if this fails, FileNotFoundError propagates to caller with open(fname_tokenizer, encoding="utf-8") as f: tokenizer_json = json.load(f) - tokenizer_model: dict[str, Any] = tokenizer_json['model'] + tokenizer_model: dict[str, Any] = tokenizer_json["model"] if ( - tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False) - or tokenizer_json['decoder']['type'] != 'ByteLevel' + tokenizer_model["type"] != "BPE" + or tokenizer_model.get("byte_fallback", False) + or tokenizer_json["decoder"]["type"] != "ByteLevel" ): - raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer') + raise FileNotFoundError("Cannot find GPT-2 BPE tokenizer") self.vocab = tokenizer_model["vocab"] - if (added := tokenizer_json.get('added_tokens')) is not None: + if (added := tokenizer_json.get("added_tokens")) is not None: # Added tokens here can be duplicates of the main vocabulary. - added_tokens = {item['content']: item['id'] - for item in added - if item['content'] not in self.vocab} + added_tokens = { + item["content"]: item["id"] + for item in added + if item["content"] not in self.vocab + } - vocab_size = len(self.vocab) + vocab_size = len(self.vocab) expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) - actual_ids = sorted(added_tokens.values()) + actual_ids = sorted(added_tokens.values()) if expected_ids != actual_ids: expected_end_id = vocab_size + len(actual_ids) - 1 - raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range " - f"{vocab_size} - {expected_end_id}; got {actual_ids}") + raise ValueError( + f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range " + f"{vocab_size} - {expected_end_id}; got {actual_ids}" + ) items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) - self.added_tokens_dict = added_tokens - self.added_tokens_list = [text for (text, idx) in items] - self.vocab_size_base = vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer + self.added_tokens_dict = added_tokens + self.added_tokens_list = [text for (text, idx) in items] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()} @@ -405,40 +473,44 @@ def __init__(self, base_path: Path): raise RuntimeError("sentencepiece is not installed") added_tokens: dict[str, int] = {} - if (fname_tokenizer := base_path / 'tokenizer.model').exists(): + if (fname_tokenizer := base_path / "tokenizer.model").exists(): # normal location try: - with open(base_path / 'added_tokens.json', encoding="utf-8") as f: + with open(base_path / "added_tokens.json", encoding="utf-8") as f: added_tokens = json.load(f) except FileNotFoundError: pass - elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists(): + elif not (fname_tokenizer := base_path.parent / "tokenizer.model").exists(): # not found in alternate location either - raise FileNotFoundError('Cannot find tokenizer.model') + raise 
FileNotFoundError("Cannot find tokenizer.model") self.sentencepiece_tokenizer = SentencePieceProcessor() self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer)) vocab_size = self.sentencepiece_tokenizer.vocab_size() - new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size} + new_tokens = { + id: piece for piece, id in added_tokens.items() if id >= vocab_size + } expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens))) - actual_new_ids = sorted(new_tokens.keys()) + actual_new_ids = sorted(new_tokens.keys()) if expected_new_ids != actual_new_ids: - raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}") + raise ValueError( + f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}" + ) # Token pieces that were added to the base vocabulary. - self.added_tokens_dict = added_tokens - self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] - self.vocab_size_base = vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer + self.added_tokens_dict = added_tokens + self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: tokenizer = self.sentencepiece_tokenizer for i in range(tokenizer.vocab_size()): piece = tokenizer.IdToPiece(i) - text = piece.encode("utf-8") + text = piece.encode("utf-8") score: float = tokenizer.GetScore(i) toktype = gguf.TokenType.NORMAL @@ -476,25 +548,27 @@ class LlamaHfVocab(Vocab): name = "hfft" def __init__(self, base_path: Path): - fname_tokenizer = base_path / 'tokenizer.json' + fname_tokenizer = base_path / "tokenizer.json" # if this fails, FileNotFoundError propagates to caller - with open(fname_tokenizer, encoding='utf-8') as f: + with open(fname_tokenizer, encoding="utf-8") as f: tokenizer_json = json.load(f) # pre-check so we know if we need transformers - tokenizer_model: dict[str, Any] = tokenizer_json['model'] + tokenizer_model: dict[str, Any] = tokenizer_json["model"] is_llama3 = ( - tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False) - and not tokenizer_model.get('byte_fallback', True) + tokenizer_model["type"] == "BPE" + and tokenizer_model.get("ignore_merges", False) + and not tokenizer_model.get("byte_fallback", True) ) if is_llama3: - raise TypeError('Llama 3 must be converted with BpeVocab') + raise TypeError("Llama 3 must be converted with BpeVocab") if not is_llama3 and ( - tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False) - or tokenizer_json['decoder']['type'] != 'Sequence' + tokenizer_model["type"] != "BPE" + or not tokenizer_model.get("byte_fallback", False) + or tokenizer_json["decoder"]["type"] != "Sequence" ): - raise FileNotFoundError('Cannot find Llama BPE tokenizer') + raise FileNotFoundError("Cannot find Llama BPE tokenizer") try: from transformers import AutoTokenizer @@ -516,7 +590,7 @@ def __init__(self, base_path: Path): # Initialize lists and dictionaries for added tokens self.added_tokens_list = [] self.added_tokens_dict = dict() - self.added_tokens_ids = set() + self.added_tokens_ids = set() # Process added tokens for tok, tokidx in sorted( @@ -537,7 +611,7 @@ def __init__(self, base_path: Path): # Set vocabulary sizes self.vocab_size_base = 
self.tokenizer.vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) self.fname_tokenizer = fname_tokenizer @@ -555,17 +629,27 @@ def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: token_text = reverse_vocab[token_id].encode("utf-8") # Yield token text, score, and type - yield token_text, self.get_token_score(token_id), self.get_token_type( - token_id, token_text, self.special_ids # Reuse already stored special IDs + yield ( + token_text, + self.get_token_score(token_id), + self.get_token_type( + token_id, + token_text, + self.special_ids, # Reuse already stored special IDs + ), ) - def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType: + def get_token_type( + self, token_id: int, token_text: bytes, special_ids: set[int] + ) -> gguf.TokenType: # Special case for byte tokens - if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text): + if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text): return gguf.TokenType.BYTE # Determine token type based on whether it's a special token - return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL + return ( + gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL + ) def get_token_score(self, token_id: int) -> float: # Placeholder for actual logic to determine the token's score @@ -575,7 +659,9 @@ def get_token_score(self, token_id: int) -> float: def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: for text in self.added_tokens_list: if text in self.specials: - toktype = self.get_token_type(self.specials[text], b'', self.special_ids) + toktype = self.get_token_type( + self.specials[text], b"", self.special_ids + ) score = self.get_token_score(self.specials[text]) else: toktype = gguf.TokenType.USER_DEFINED @@ -592,3 +678,250 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: def __repr__(self) -> str: return f"" + + +class MistralTokenizerType(str, Enum): + spm = "spm" + tekken = "tekken" + + +# Copied from Transformers (Apache 2.0) +# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544 + +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +class MistralVocab(Vocab): + tokenizer_model = "mistral" + name = "mistral" + + added_tokens_dict: dict[str, int] = {} + added_tokens_list: list[str] = [] + + def __init__(self, base_path: Path): + if not _mistral_common_installed: + raise ImportError( + "To use MistralVocab, please install the `mistral-common` package. " + "You can install it with `pip install mistral-common`." 
+ ) + + # Find the tokenizer files + all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()] + valid_tokenizer_files = filter_valid_tokenizer_files(all_files) + + if len(valid_tokenizer_files) == 0: + raise ValueError(f"No tokenizer file found in the directory: {base_path}") + # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one. + if len(valid_tokenizer_files) > 1: + if "tekken.json" in valid_tokenizer_files: + tokenizer_file = "tekken.json" + else: + tokenizer_file = sorted(valid_tokenizer_files)[-1] + logger.warning( + f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}" + ) + else: + tokenizer_file = valid_tokenizer_files[0] + + self.tokenizer = MistralTokenizer.from_file( + base_path / tokenizer_file + ).instruct_tokenizer.tokenizer + self.tokenizer_type = ( + MistralTokenizerType.tekken + if isinstance(self.tokenizer, Tekkenizer) + else MistralTokenizerType.spm + ) + self.vocab_size = self.tokenizer.n_words + self.fname_tokenizer = base_path / tokenizer_file + self._name = ( + "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version + ) + + @property + def tokenizer_name(self) -> str: + return self._name + + @property + def gguf_tokenizer_model(self) -> str: + return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2" + + def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + assert isinstance(self.tokenizer, SentencePieceTokenizer), ( + f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}" + ) + + for i in range(self.tokenizer._model.vocab_size()): + piece = self.tokenizer._model.IdToPiece(i) + text = piece.encode("utf-8") + score: float = self.tokenizer._model.GetScore(i) + + toktype = gguf.TokenType.NORMAL + if self.tokenizer._model.IsUnknown(i): + toktype = gguf.TokenType.UNKNOWN + if self.tokenizer._model.IsControl(i): + toktype = gguf.TokenType.CONTROL + + if self.tokenizer._model.IsUnused(i): + toktype = gguf.TokenType.UNUSED + if self.tokenizer._model.IsByte(i): + toktype = gguf.TokenType.BYTE + + yield text, score, toktype + + def _tekken_tokens(self) -> Iterable[tuple[str, float, gguf.TokenType]]: + assert isinstance(self.tokenizer, Tekkenizer), ( + f"Expected Tekkenizer, got {type(self.tokenizer)}" + ) + + byte_encoder = bytes_to_unicode() + for token_id in range(self.tokenizer.num_special_tokens): + token = self.tokenizer.id_to_piece(token_id) + yield token, 0, gguf.TokenType.CONTROL + for token in self.tokenizer._tekken_token2id_nospecial: + yield ( + self.token_bytes_to_string(token, byte_encoder), + 0, + gguf.TokenType.NORMAL, + ) + + def get_token_id(self, token: str) -> int: + if self.tokenizer_type == MistralTokenizerType.spm: + return self.tokenizer._vocab.index(token) + elif self.tokenizer_type == MistralTokenizerType.tekken: + return ( + self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens + ) + else: + raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}") + + @property + def bos_id(self) -> int: + return self.tokenizer.bos_id + + @property + def eos_id(self) -> int: + return self.tokenizer.eos_id + + @property + def pad_id(self) -> int: + if self.tokenizer.pad_id == -1: + return self.eos_id + return self.tokenizer.pad_id + + @property + def unk_id(self) -> int: + return self.tokenizer.unk_id + + @property + def bos_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.bos_id) + + @property + def eos_token(self) -> str: + return 
self.tokenizer.id_to_piece(self.tokenizer.eos_id) + + @property + def pad_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.pad_id) + + @property + def unk_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.unk_id) + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + if self.tokenizer_type == MistralTokenizerType.spm: + yield from self._sentencepiece_tokens() + + elif self.tokenizer_type == MistralTokenizerType.tekken: + yield from self._tekken_tokens() + + else: + raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}") + + @staticmethod + def token_bytes_to_string(b, byte_encoder): + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + def extract_vocab_merges_from_model(self): + # Adapted from Transformers (Apache 2.0) + # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py + assert self.tokenizer_type == MistralTokenizerType.tekken, ( + f"Expected Tekkenizer, got {type(self.tokenizer)}" + ) + + mergeable_ranks = self.tokenizer._model._mergeable_ranks + token_bytes_map = { + rank: token_bytes for token_bytes, rank in mergeable_ranks.items() + } + merge_pairs = [] + + # Sort vocab by rank to ensure correct merge order + for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens): + merged_token = token_bytes_map[i] + local = [] + for j in range(1, len(merged_token)): + left = merged_token[:j] + right = merged_token[j:] + if ( + left in mergeable_ranks + and right in mergeable_ranks + and (left + right) in mergeable_ranks + ): + local.append((left, right, i)) + if not local: + raise ValueError( + f"Could not find valid merge for token at rank {i}: {merged_token}" + ) + local = sorted( + local, + key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]), + reverse=False, + ) + merge_pairs.extend(local) + merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False) + + byte_encoder = bytes_to_unicode() + + merge_pairs = [ + [ + self.token_bytes_to_string(val[0], byte_encoder), + self.token_bytes_to_string(val[1], byte_encoder), + ] + for val in merge_pairs + ] + + merges = [ + " ".join( + [ + # ensure the spaces are properly encoded + "".join(chr(ord(c) + 256) if c == " " else c for c in part) + for part in pair + ] + ) + for pair in merge_pairs + ] + + return merges diff --git a/pyproject.toml b/pyproject.toml index 3d71b055a8dbf..69ea98c1dbb8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,4 +42,5 @@ build-backend = "poetry.core.masonry.api" llama-convert-hf-to-gguf = "convert_hf_to_gguf:main" llama-convert-lora-to-gguf = "convert_lora_to_gguf:main" llama-convert-llama-ggml-to-gguf = "convert_llama_ggml_to_gguf:main" +llama-convert-mistral-to-gguf = "convert_mistral_to_gguf:main" llama-ggml-vk-generate-shaders = "ggml_vk_generate_shaders:main" diff --git a/requirements.txt b/requirements.txt index f2a18d62879b4..9120254ca1f49 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ -r ./requirements/requirements-convert_hf_to_gguf_update.txt -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt -r ./requirements/requirements-convert_lora_to_gguf.txt +-r ./requirements/requirements-convert_mistral_to_gguf.txt -r ./requirements/requirements-tool_bench.txt diff --git a/requirements/requirements-convert_mistral_to_gguf.txt b/requirements/requirements-convert_mistral_to_gguf.txt new file mode 100644 index 0000000000000..fbe9a4b6d8dd9 --- /dev/null +++ 
b/requirements/requirements-convert_mistral_to_gguf.txt @@ -0,0 +1,12 @@ +numpy<2.0.0 +gguf>=0.1.0 +protobuf>=4.21.0,<5.0.0 +mistral-common[hf-hub]>=1.8.0 +safetensors>=0.5.3 + +--extra-index-url https://download.pytorch.org/whl/cpu +torch~=2.2.1; platform_machine != "s390x" + +# torch s390x packages can only be found from nightly builds +--extra-index-url https://download.pytorch.org/whl/nightly +torch>=0.0.0.dev0; platform_machine == "s390x" \ No newline at end of file diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 9454d04e53801..dcf9552180492 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -86,6 +86,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_LFM2, "lfm2" }, { LLM_ARCH_DREAM, "dream" }, + { LLM_ARCH_MISTRAL, "mistral" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1886,6 +1887,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, } }, + { + LLM_ARCH_MISTRAL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 0ead0d6cdb11b..e3e3abf030dbf 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -90,6 +90,7 @@ enum llm_arch { LLM_ARCH_SMOLLM3, LLM_ARCH_LFM2, LLM_ARCH_DREAM, + LLM_ARCH_MISTRAL, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index cdf1e424294e5..797dc43ad78e9 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -570,6 +570,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_MISTRAL: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2018,6 +2019,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_MISTRAL: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -16736,6 +16738,7 @@ llm_graph_result_ptr llama_model::build_graph( switch (arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_MISTRAL: { llm = std::make_unique(*this, params, gf); } break; @@ -17206,6 +17209,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_SMOLLM3: case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_MISTRAL: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2