convert-hf : support bfloat16 conversion #7158


Merged
merged 7 commits on May 11, 2024
246 changes: 118 additions & 128 deletions convert-hf-to-gguf.py
@@ -12,7 +12,7 @@
from enum import IntEnum
from pathlib import Path
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast

import numpy as np
import torch
@@ -65,8 +65,8 @@ class Model:
model_arch: gguf.MODEL_ARCH

def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
if self.__class__ == Model:
raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model
self.ftype = ftype
self.fname_out = fname_out
@@ -83,6 +83,15 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self.tensor_names = None
if self.ftype == gguf.LlamaFileType.GUESSED:
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
_, first_tensor = next(self.get_tensors())
if first_tensor.dtype == torch.float16:
logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_F16
else:
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_BF16

@classmethod
def __init_subclass__(cls):
@@ -142,14 +151,27 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")

def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
name: str = gguf.TENSOR_NAMES[key]
if key not in gguf.MODEL_TENSORS[self.model_arch]:
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
name: str = gguf.TENSOR_NAMES[key]
if "{bid}" in name:
assert bid is not None
name = name.format(bid=bid)
return name + suffix

def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool:
if key not in gguf.MODEL_TENSORS[self.model_arch]:
return False
key_name: str = gguf.TENSOR_NAMES[key]
if "{bid}" in key_name:
if bid is None:
return False
key_name = key_name.format(bid=bid)
else:
if bid is not None:
return False
return name == (key_name + suffix)

def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
if new_name is None:
@@ -215,6 +237,23 @@ def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: i
return False

def write_tensors(self):
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
def np_fp32_to_bf16(n: np.ndarray):
# force nan to quiet
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
Contributor

Everything looks good. This is the only line that bugs me. Not sure about the (64 << 16) for logical OR. I'm knee deep in a dataset, so my mental state is not 100% there right now. Could be nothing.

Collaborator Author

It's doing the equivalent of these lines in ggml-impl.h:

https://github.com/ggerganov/llama.cpp/blob/befddd0f15de6efb15d7e7f5b527dfb671f4196f/ggml-impl.h#L82-L85

(64 << 16) came from adapting those lines while postponing the right shift until after rounding.
(n & 0xffff0000) avoids rounding up NaNs by setting the low bits to zero.
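
For reference, a minimal standalone sketch of the same rounding (the function name, the uint32 view, and the uint16 output type are illustrative choices here; the PR itself works on int32/int16 views):

import numpy as np

def fp32_to_bf16_bits(a: np.ndarray) -> np.ndarray:
    # view the float32 bit patterns as unsigned integers
    n = a.astype(np.float32).view(np.uint32)
    # quiet any NaN: keep sign and exponent, set mantissa bit 22, clear the low bits
    n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
    # flush subnormals to (signed) zero
    n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
    # add half of the discarded low 16 bits, plus the parity bit of the kept part
    # for round-to-nearest-even, then keep the top 16 bits
    n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
    return n.astype(np.uint16)

# 1.00390625 is exactly halfway between the bf16 values 1.0 (0x3f80) and
# 1.0078125 (0x3f81), and rounds to the even one:
print(hex(fp32_to_bf16_bits(np.array([1.00390625], dtype=np.float32))[0]))  # 0x3f80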

# flush subnormals to zero
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
# round to nearest even
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
return n.astype(np.int16)

# Doing this row-wise is much, much faster than element-wise, hence the signature
v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
if self.lazy:
# TODO: find a way to implicitly wrap np.vectorize functions
# NOTE: the type is changed to reflect otypes passed to np.vectorize above
v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)

max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")

for name, data_torch in self.get_tensors():
@@ -239,35 +278,60 @@ def write_tensors(self):
data: np.ndarray = data # type hint
n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
data_qtype: gguf.GGMLQuantizationType | None = None

# when both are True, f32 should win
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)

# Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
extra_f32 = extra_f32 or n_dims == 1 or new_name.endswith("_norm.weight")
# Conditions should closely match those in llama_model_quantize_internal in llama.cpp
extra_f32 = any(cond for cond in (
extra_f32,
n_dims == 1,
new_name.endswith("_norm.weight"),
))

# Some tensor types are always in float32
extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
))

# if f16 desired, convert any float32 2-dim weight tensors to float16
extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)

# when both extra_f32 and extra_f16 are False, convert to float32 by default
if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
data = data.astype(np.float32)

if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
data = data.astype(np.float16)
extra_f16 = any(cond for cond in (
extra_f16,
(name.endswith(".weight") and n_dims >= 2),
))

if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
if data_dtype != np.float16:
data = data.astype(np.float16)
data_qtype = gguf.GGMLQuantizationType.F16

elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
if data_dtype != np.float32:
data = data.astype(np.float32)
data = v_fp32_to_bf16(data.view(np.int32))
assert data.dtype == np.int16
data_qtype = gguf.GGMLQuantizationType.BF16

else: # by default, convert to float32
if data_dtype != np.float32:
data = data.astype(np.float32)
data_qtype = gguf.GGMLQuantizationType.F32

assert data_qtype is not None

# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"

# n_dims is implicit in the shape
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

self.gguf_writer.add_tensor(new_name, data)
self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)

def write(self):
self.write_tensors()
@@ -2023,12 +2087,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

return [(self.map_tensor_name(name), data_torch)]

def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
del new_name, bid, n_dims # unused

# not used with get_rows, must be F32
return name == "embeddings.token_type_embeddings.weight"


@Model.register("NomicBertModel")
class NomicBertModel(BertModel):
@@ -2281,92 +2339,40 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter


# tree of lazy tensors
class LazyTorchTensor:
_meta: Tensor
_data: Tensor | None
_args: tuple
_func: Callable[[tuple], Tensor] | None

def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
self._meta = meta
self._data = data
self._args = args
self._func = func

@staticmethod
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
# TODO: dict and set
if isinstance(o, (list, tuple)):
L = []
for item in o:
L.append(LazyTorchTensor._recurse_apply(item, fn))
if isinstance(o, tuple):
L = tuple(L)
return L
elif isinstance(o, LazyTorchTensor):
return fn(o)
else:
return o

def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
def wrapped_fn(*args, **kwargs):
if kwargs is None:
kwargs = {}
args = ((self,) if use_self else ()) + args

meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)

return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
return wrapped_fn

def __getattr__(self, __name: str) -> Any:
meta_attr = getattr(self._meta, __name)
if callable(meta_attr):
return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
elif isinstance(meta_attr, torch.Tensor):
# for things like self.T
return self._wrap_fn(lambda s: getattr(s, __name))(self)
else:
return meta_attr
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor
# to keep the type-checker happy
dtype: torch.dtype
shape: torch.Size

# only used when converting a torch.Tensor to a np.ndarray
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
}

def numpy(self) -> gguf.LazyTensor:
def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)

@overload
@staticmethod
def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...

@overload
@staticmethod
def to_eager(t: tuple) -> tuple: ...

@staticmethod
def to_eager(t: Any) -> Any:
def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
# wake up the lazy tensor
if _t._data is None and _t._func is not None:
# recurse into its arguments
_t._args = LazyTorchTensor.to_eager(_t._args)
_t._data = _t._func(_t._args)
if _t._data is not None:
return _t._data
else:
raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")

# recurse into lists and/or tuples, keeping their structure
return LazyTorchTensor._recurse_apply(t, simple_to_eager)
return gguf.LazyNumpyTensor(
meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
lazy=self._lazy,
args=(self,),
func=(lambda s: s[0].numpy())
)

@staticmethod
def from_eager(t: Tensor) -> Tensor:
if (t.__class__ == LazyTorchTensor):
@classmethod
def eager_to_meta(cls, t: Tensor) -> Tensor:
if t.is_meta:
return t
return LazyTorchTensor(meta=t.detach().to("meta"), data=t) # type: ignore
return t.detach().to("meta")

@classmethod
def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
m = m.detach()
if not m.is_meta:
m = m.to("meta")
m.dtype = dtype
return m

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -2377,28 +2383,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):

if func is torch.Tensor.numpy:
return args[0].numpy()
if func is torch.equal:
eager_args = LazyTorchTensor.to_eager(args)
return func(*eager_args, **kwargs)

return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)

# special methods bypass __getattr__, so they need to be added manually
# ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
# NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
# as self._meta is currently used), because then the following
# operations would by default not be wrapped, and so not propagated
# when the tensor is made eager.
# It's better to get non-silent errors for not-yet-supported operators.
# TODO: add more when needed to avoid clutter, or find a more concise way
def __neg__(self, *args): # mamba
return self._wrap_fn(torch.Tensor.__neg__)(self, *args)

def __add__(self, *args): # gemma
return self._wrap_fn(torch.Tensor.__add__)(self, *args)

def __getitem__(self, *args): # bloom falcon refact internlm2
return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)


def parse_args() -> argparse.Namespace:
@@ -2417,8 +2403,8 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16"], default="f16",
help="output format - use f32 for float32, f16 for float16",
"--outtype", type=str, choices=["f32", "f16", "bf16", "auto-f16"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto-f16 for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian", action="store_true",
@@ -2472,9 +2458,11 @@ def main() -> None:
logger.error(f'Error: {args.model} is not a directory')
sys.exit(1)

ftype_map = {
"f32": gguf.GGMLQuantizationType.F32,
"f16": gguf.GGMLQuantizationType.F16,
ftype_map: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"auto-f16": gguf.LlamaFileType.GUESSED, # TODO: use a more appropriate "auto" type
}

if args.outfile is not None:
@@ -2497,6 +2485,8 @@ def main() -> None:
logger.info("Set model tokenizer")
model_instance.set_vocab()

model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION);

if args.vocab_only:
logger.info(f"Exporting model vocab to '{fname_out}'")
model_instance.write_vocab()
1 change: 1 addition & 0 deletions gguf-py/gguf/__init__.py
@@ -1,4 +1,5 @@
from .constants import *
from .lazy import *
from .gguf_reader import *
from .gguf_writer import *
from .tensor_mapping import *
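
For anyone trying this out, a usage sketch based on the options added in this PR (the model path and output filename are placeholders, and the positional model-directory argument is the one checked in main()):

python convert-hf-to-gguf.py /path/to/hf-model --outtype bf16 --outfile model-bf16.gguf

Passing --outtype auto-f16 instead picks f16 or bf16 from the dtype of the first loaded tensor, as implemented in Model.__init__ above.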