convert-hf : support bfloat16 conversion #7158


Merged · 7 commits · May 11, 2024

Changes from 3 commits
190 changes: 73 additions & 117 deletions convert-hf-to-gguf.py
@@ -12,7 +12,7 @@
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
 
 import numpy as np
 import torch
@@ -65,8 +65,8 @@ class Model:
     model_arch: gguf.MODEL_ARCH
 
     def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
-        if self.__class__ == Model:
-            raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
+        if type(self) is Model:
+            raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
         self.ftype = ftype
         self.fname_out = fname_out
@@ -215,6 +215,23 @@ def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: i
         return False
 
     def write_tensors(self):
+        # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
+        def np_fp32_to_bf16(n: np.ndarray):
+            # force nan to quiet
+            n = np.where((n & 0x7fffffff) > 0x7f800000, n | (64 << 16), n)
+            # flush subnormals to zero
+            n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
+            # round to nearest even
+            n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
+            return n
+
+        # Doing this row-wise is much, much faster than element-wise, hence the signature
+        v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
+        if self.lazy:
+            # TODO: find a way to implicitly wrap np.vectorize functions
+            # NOTE: the type is changed to reflect otypes passed to np.vectorize above
+            v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
+
         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
 
         for name, data_torch in self.get_tensors():
@@ -239,10 +256,7 @@ def write_tensors(self):
                 data: np.ndarray = data  # type hint
                 n_dims = len(data.shape)
                 data_dtype = data.dtype
-
-                # if f32 desired, convert any float16 to float32
-                if self.ftype == 0 and data_dtype == np.float16:
-                    data = data.astype(np.float32)
+                data_qtype: gguf.GGMLQuantizationType | None = None
 
                 # when both are True, f32 should win
                 extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
@@ -254,20 +268,33 @@ def write_tensors(self):
                 # if f16 desired, convert any float32 2-dim weight tensors to float16
                 extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)
 
-                # when both extra_f32 and extra_f16 are False, convert to float32 by default
-                if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
-                    data = data.astype(np.float32)
+                if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
+                    if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+                        if data_dtype != np.float16:
+                            data = data.astype(np.float16)
+                        data_qtype = gguf.GGMLQuantizationType.F16
+
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                        if data_dtype != np.float32:
+                            data = data.astype(np.float32)
+                        data = v_fp32_to_bf16(data.view(np.int32))
+                        assert data.dtype == np.int16
+                        data_qtype = gguf.GGMLQuantizationType.BF16
 
-                if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
-                    data = data.astype(np.float16)
+                else:  # by default, convert to float32
+                    if data_dtype != np.float32:
+                        data = data.astype(np.float32)
+                    data_qtype = gguf.GGMLQuantizationType.F32
+
+                assert data_qtype is not None
 
                 # reverse shape to make it similar to the internal ggml dimension order
                 shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
 
                 # n_dims is implicit in the shape
-                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
+                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
 
-                self.gguf_writer.add_tensor(new_name, data)
+                self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
 
     def write(self):
         self.write_tensors()
@@ -2281,92 +2308,40 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
 
 # tree of lazy tensors
-class LazyTorchTensor:
-    _meta: Tensor
-    _data: Tensor | None
-    _args: tuple
-    _func: Callable[[tuple], Tensor] | None
-
-    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
-        self._meta = meta
-        self._data = data
-        self._args = args
-        self._func = func
-
-    @staticmethod
-    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
-        # TODO: dict and set
-        if isinstance(o, (list, tuple)):
-            L = []
-            for item in o:
-                L.append(LazyTorchTensor._recurse_apply(item, fn))
-            if isinstance(o, tuple):
-                L = tuple(L)
-            return L
-        elif isinstance(o, LazyTorchTensor):
-            return fn(o)
-        else:
-            return o
-
-    def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
-        def wrapped_fn(*args, **kwargs):
-            if kwargs is None:
-                kwargs = {}
-            args = ((self,) if use_self else ()) + args
-
-            meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
-
-            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
-        return wrapped_fn
-
-    def __getattr__(self, __name: str) -> Any:
-        meta_attr = getattr(self._meta, __name)
-        if callable(meta_attr):
-            return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
-        elif isinstance(meta_attr, torch.Tensor):
-            # for things like self.T
-            return self._wrap_fn(lambda s: getattr(s, __name))(self)
-        else:
-            return meta_attr
+class LazyTorchTensor(gguf.LazyBase):
+    _tensor_type = torch.Tensor
+    # to keep the type-checker happy
+    dtype: torch.dtype
+    shape: torch.Size
 
     # only used when converting a torch.Tensor to a np.ndarray
     _dtype_map: dict[torch.dtype, type] = {
         torch.float16: np.float16,
        torch.float32: np.float32,
     }
 
-    def numpy(self) -> gguf.LazyTensor:
+    def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
-        return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
-
-    @overload
-    @staticmethod
-    def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
-
-    @overload
-    @staticmethod
-    def to_eager(t: tuple) -> tuple: ...
-
-    @staticmethod
-    def to_eager(t: Any) -> Any:
-        def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
-            # wake up the lazy tensor
-            if _t._data is None and _t._func is not None:
-                # recurse into its arguments
-                _t._args = LazyTorchTensor.to_eager(_t._args)
-                _t._data = _t._func(_t._args)
-            if _t._data is not None:
-                return _t._data
-            else:
-                raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
-
-        # recurse into lists and/or tuples, keeping their structure
-        return LazyTorchTensor._recurse_apply(t, simple_to_eager)
+        return gguf.LazyNumpyTensor(
+            meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
+            lazy=self._lazy,
+            args=(self,),
+            func=(lambda s: s[0].numpy())
+        )
 
-    @staticmethod
-    def from_eager(t: Tensor) -> Tensor:
-        if (t.__class__ == LazyTorchTensor):
-            return t
-        return LazyTorchTensor(meta=t.detach().to("meta"), data=t)  # type: ignore
+    @classmethod
+    def eager_to_meta(cls, t: Tensor) -> Tensor:
+        if t.is_meta:
+            return t
+        return t.detach().to("meta")
+
+    @classmethod
+    def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
+        m = m.detach()
+        if not m.is_meta:
+            m = m.to("meta")
+        m.dtype = dtype
+        return m
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -2377,28 +2352,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
 
         if func is torch.Tensor.numpy:
             return args[0].numpy()
-        if func is torch.equal:
-            eager_args = LazyTorchTensor.to_eager(args)
-            return func(*eager_args, **kwargs)
 
-        return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
-
-    # special methods bypass __getattr__, so they need to be added manually
-    # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
-    # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
-    #       as self._meta is currently used), because then the following
-    #       operations would by default not be wrapped, and so not propagated
-    #       when the tensor is made eager.
-    #       It's better to get non-silent errors for not-yet-supported operators.
-    # TODO: add more when needed to avoid clutter, or find a more concise way
-    def __neg__(self, *args):  # mamba
-        return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
-
-    def __add__(self, *args):  # gemma
-        return self._wrap_fn(torch.Tensor.__add__)(self, *args)
-
-    def __getitem__(self, *args):  # bloom falcon refact internlm2
-        return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
+        return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
 
 
 def parse_args() -> argparse.Namespace:
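
As an aside, the lazy machinery above builds on PyTorch's "meta" device, where tensors carry shape and dtype but no storage. A minimal sketch of the idea (illustration only, not code from this PR):

```python
import torch

# Tensors on the "meta" device have metadata but no data, so operations
# on them only propagate shapes and dtypes -- this is how the lazy
# wrapper can plan a conversion without materializing every weight.
t = torch.empty(4096, 4096, dtype=torch.float32, device="meta")
u = (t + 1).T               # no real computation happens here
print(u.shape, u.is_meta)   # torch.Size([4096, 4096]) True

# An eager tensor can be reduced to its metadata-only shadow,
# as eager_to_meta() does above.
w = torch.ones(8).detach().to("meta")
print(w.is_meta)            # True
```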
@@ -2417,8 +2372,8 @@ def parse_args() -> argparse.Namespace:
         help="path to write to; default: based on input",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16"], default="f16",
-        help="output format - use f32 for float32, f16 for float16",
+        "--outtype", type=str, choices=["f32", "f16", "bf16"], default="f16",
+        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16",
     )
     parser.add_argument(
         "--bigendian", action="store_true",
Review thread on the --outtype change:

Collaborator:

Given most models come in bf16 wouldn't it make sense to set it as default?

compilade (Collaborator, Author), May 9, 2024:

The conversion to bf16 is slightly slower and uses a bit more RAM than f16 conversion, due to the lack of native NumPy support, so I didn't change the default.

I'll see if I can auto-detect whether the model contains bf16 tensors (but it will most likely be too complicated). Otherwise, it does make sense to set bf16 as default if it's widely used.

My concern with bf16 as default is that f16 -> bf16 is more lossy than bf16 -> f16, since 3 bits of the mantissa are always lost in f16 -> bf16, while bf16 -> f16 only turns some very-close-to-zero values into zero, and big values get turned into inf (but such big values are usually not in model weights, see #7075 (comment)).
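
To make the asymmetry concrete, here is a small sketch using plain NumPy bit manipulation (same rounding scheme as np_fp32_to_bf16 in this PR, with NaN/subnormal handling omitted; an illustration, not code from the PR):

```python
import numpy as np

def f32_to_bf16_bits(x: np.ndarray) -> np.ndarray:
    # round-to-nearest-even truncation to the upper 16 bits of float32,
    # as in np_fp32_to_bf16 above (NaN/subnormal handling omitted)
    n = x.view(np.int32)
    n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
    return n.astype(np.int16)

def bf16_bits_to_f32(b: np.ndarray) -> np.ndarray:
    # widen bf16 bit patterns back to float32 by appending 16 zero bits
    return (b.astype(np.int32) << 16).view(np.float32)

# f16 -> bf16: 1 + 2**-10 is exact in f16 (10 mantissa bits) but rounds
# away in bf16 (7 mantissa bits): 3 bits of precision are always lost.
x = np.array([1.0 + 2.0**-10], dtype=np.float16).astype(np.float32)
print(bf16_bits_to_f32(f32_to_bf16_bits(x)))  # [1.]

# bf16 -> f16: exact for typical weight magnitudes; only values outside
# f16's range (here 2**20) overflow to inf.
y = np.array([3.140625, 2.0**20], dtype=np.float32)  # both exact in bf16
print(y.astype(np.float16))  # second value overflows to inf
```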

ggerganov (Member):

We only have rudimentary CPU-only bf16 support in ggml, so f16 is better for now.

bartowski1182 (Contributor):

(left comment on wrong account)

> I'll see if I can auto-detect whether the model contains bf16 tensors (but it will most likely be too complicated). Otherwise, it does make sense to set bf16 as default if it's widely used.

@compilade could we not attempt to read from config.json? It should have a torch_dtype in it.

@ggerganov, when you say CPU-only, I assume you're referring to inference, since all conversion and quantization is currently CPU-only?

compilade (Collaborator, Author), May 9, 2024:

> could we not attempt to read from config.json? it should have a torch_dtype in it

@bartowski1182 Yes, but not all models define that field, so I think a second guess based on the type of the first tensor in the model will sometimes be necessary.

⚠️ Also, some models say one thing in config.json and use another type in the model files. For example, https://huggingface.co/pansophic/rocket-3B has F16 tensors, but defines torch_dtype as bfloat16 in config.json.

Would you still like some kind of --outtype auto-f16 based on the model content, even if f16 is kept as the default --outtype otherwise (due to slightly faster conversion and more complete backend support)?
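
For illustration, the detection idea could look roughly like this (a hedged sketch; the helper name guess_output_type and the fallback policy are hypothetical, not the PR's code):

```python
import json
from pathlib import Path

def guess_output_type(dir_model: Path) -> str:
    # read torch_dtype from config.json when present...
    config = json.loads((dir_model / "config.json").read_text(encoding="utf-8"))
    dtype_to_ftype = {"bfloat16": "bf16", "float16": "f16", "float32": "f32"}
    torch_dtype = config.get("torch_dtype")
    if torch_dtype in dtype_to_ftype:
        # ...but note it can disagree with the actual tensors
        # (e.g. pansophic/rocket-3B stores F16 yet declares bfloat16),
        # so a robust version would also check the first tensor's dtype
        return dtype_to_ftype[torch_dtype]
    return "f16"  # fallback default
```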

compilade (Collaborator, Author), May 10, 2024:

> Can you explain why bf16 is only used for CPU? Will there be GPU support in the future?

@htdung167 This PR is only about conversion, and the convert script has always been CPU-only. Inference (text generation) with bf16 was worked on by @jart in #6412. A relevant comment from there regarding future GPU support is #6412 (comment).

> I think '--outtype auto' might be fine, since at its core that's what it's doing

@bartowski1182 I agree. I've thought about this more, and I'll change this to auto instead of auto-f16. I don't think there will be auto-anything-else anyway.

> is it possible to figure out the auto-chosen version earlier and use it for naming the outfile?

Yes, this is already possible by checking the logs. It would also be possible to do automatically, but not everyone has the same naming conventions, so maybe the right way to do this would be with a .format() pattern? For example, --outfile llama-3-8b-instruct-{ftype}.gguf, or --outfile llama-3-8b-instruct-{outtype}.gguf, or --outfile llama-3-8b-instruct-{}.gguf. Not sure which to support (all?), but it should be clearer than %s. It would also be possible to allow using {FTYPE} or {OUTTYPE} for upper-cased type names.

I see you've used fp16 and fp32 in the past, but this will use f16 and f32, respectively, for these type names.

(EDIT: this is now implemented in e0af2df)

bartowski1182 (Contributor):

Yeah, I've made the transition to f16/f32; naming them 'fp' was an oversight on my part.

A format option would be amazing.
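
A sketch of what such a .format()-style pattern could look like (a hypothetical helper for illustration; the PR's actual implementation landed in e0af2df):

```python
def fill_templated_filename(filename: str, output_type: str) -> str:
    # expand {ftype}/{outtype} (and upper-cased variants) in --outfile
    lower = output_type.lower()
    return filename.format(ftype=lower, outtype=lower,
                           FTYPE=lower.upper(), OUTTYPE=lower.upper())

print(fill_templated_filename("llama-3-8b-instruct-{ftype}.gguf", "bf16"))
# llama-3-8b-instruct-bf16.gguf
```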

compilade (Collaborator, Author), May 10, 2024:

Should auto simply try to be as lossless as possible? Like, if the model is originally in f32, make the output f32? Or should it always select a 16-bit type? (Currently bf16 is selected for f32 models.)

Contributor:

My vote would be on compressing even if it was originally f32; if the person converting wants f32, they'll specify. Otherwise, presumably, they're converting with the intention of quantizing, where it won't matter.

Contributor:

"Do no harm" should be the default.

help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16",
)
parser.add_argument(
"--bigendian", action="store_true",
@@ -2472,9 +2427,10 @@ def main() -> None:
         logger.error(f'Error: {args.model} is not a directory')
         sys.exit(1)
 
-    ftype_map = {
-        "f32": gguf.GGMLQuantizationType.F32,
-        "f16": gguf.GGMLQuantizationType.F16,
+    ftype_map: dict[str, gguf.LlamaFileType] = {
+        "f32": gguf.LlamaFileType.ALL_F32,
+        "f16": gguf.LlamaFileType.MOSTLY_F16,
+        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
     }
 
     if args.outfile is not None:
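As a quick check of what the np_fp32_to_bf16 helper above produces, here is a standalone sketch (illustration only, not part of the diff) running the same bit manipulation on a few sample values:

```python
import numpy as np

def np_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
    # identical to the helper added in write_tensors above, operating
    # on float32 bit patterns viewed as int32
    n = np.where((n & 0x7fffffff) > 0x7f800000, n | (64 << 16), n)  # quiet NaNs
    n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)          # flush subnormals
    n = (n + (0x7fff + ((n >> 16) & 1))) >> 16                      # round to nearest even
    return n.astype(np.int16)

x = np.array([1.0, np.pi, 65504.0], dtype=np.float32)
bf16 = np_fp32_to_bf16(x.view(np.int32))
# widen back to float32 by re-appending 16 zero bits, to inspect the result
back = (bf16.astype(np.int32) << 16).view(np.float32)
print(back)  # values: 1.0, 3.140625, 65536.0 -- pi loses its low mantissa
             # bits, and 65504 rounds up to 65536
```
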
1 change: 1 addition & 0 deletions gguf-py/gguf/__init__.py
@@ -1,4 +1,5 @@
 from .constants import *
+from .lazy import *
 from .gguf_reader import *
 from .gguf_writer import *
 from .tensor_mapping import *
41 changes: 41 additions & 0 deletions gguf-py/gguf/constants.py
@@ -820,6 +820,47 @@ class GGMLQuantizationType(IntEnum):
     BF16 = 30
 
 
+# TODO: add GGMLFileType from ggml_ftype in ggml.h
+
+
+# from llama_ftype in llama.h
+# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
+class LlamaFileType(IntEnum):
+    ALL_F32              = 0
+    MOSTLY_F16           = 1   # except 1d tensors
+    MOSTLY_Q4_0          = 2   # except 1d tensors
+    MOSTLY_Q4_1          = 3   # except 1d tensors
+    MOSTLY_Q4_1_SOME_F16 = 4   # tok_embeddings.weight and output.weight are F16
+    # MOSTLY_Q4_2        = 5   # support has been removed
+    # MOSTLY_Q4_3        = 6   # support has been removed
+    MOSTLY_Q8_0          = 7   # except 1d tensors
+    MOSTLY_Q5_0          = 8   # except 1d tensors
+    MOSTLY_Q5_1          = 9   # except 1d tensors
+    MOSTLY_Q2_K          = 10  # except 1d tensors
+    MOSTLY_Q3_K_S        = 11  # except 1d tensors
+    MOSTLY_Q3_K_M        = 12  # except 1d tensors
+    MOSTLY_Q3_K_L        = 13  # except 1d tensors
+    MOSTLY_Q4_K_S        = 14  # except 1d tensors
+    MOSTLY_Q4_K_M        = 15  # except 1d tensors
+    MOSTLY_Q5_K_S        = 16  # except 1d tensors
+    MOSTLY_Q5_K_M        = 17  # except 1d tensors
+    MOSTLY_Q6_K          = 18  # except 1d tensors
+    MOSTLY_IQ2_XXS       = 19  # except 1d tensors
+    MOSTLY_IQ2_XS        = 20  # except 1d tensors
+    MOSTLY_Q2_K_S        = 21  # except 1d tensors
+    MOSTLY_IQ3_XS        = 22  # except 1d tensors
+    MOSTLY_IQ3_XXS       = 23  # except 1d tensors
+    MOSTLY_IQ1_S         = 24  # except 1d tensors
+    MOSTLY_IQ4_NL        = 25  # except 1d tensors
+    MOSTLY_IQ3_S         = 26  # except 1d tensors
+    MOSTLY_IQ3_M         = 27  # except 1d tensors
+    MOSTLY_IQ2_S         = 28  # except 1d tensors
+    MOSTLY_IQ2_M         = 29  # except 1d tensors
+    MOSTLY_IQ4_XS        = 30  # except 1d tensors
+    MOSTLY_IQ1_M         = 31  # except 1d tensors
+    MOSTLY_BF16          = 32  # except 1d tensors
+
+
 class GGUFEndian(IntEnum):
     LITTLE = 0
     BIG = 1