Commit 866865e

convert-hf : save memory with lazy evaluation
1 parent f2099c5 commit 866865e

2 files changed: +196 −10 lines

convert-hf-to-gguf.py

Lines changed: 131 additions & 4 deletions
@@ -12,7 +12,7 @@
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload

 import numpy as np
 import torch
@@ -63,7 +63,7 @@ class Model:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

-    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
+    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
         if self.__class__ == Model:
             raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
@@ -81,6 +81,9 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensors = dict(self.get_tensors())
+        if not eager:
+            for k, v in self.tensors.items():
+                self.tensors[k] = LazyTorchTensor.from_eager(v)

     @classmethod
     def __init_subclass__(cls):
@@ -245,9 +248,11 @@ def write_tensors(self):

     def write(self):
         self.write_tensors()
+        self.tensors.clear()  # save memory by not keeping references to the tensors
+
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_tensors_to_file()
+        self.gguf_writer.write_tensors_to_file(progress=True)
         self.gguf_writer.close()

     def write_vocab(self):
22292234
###### CONVERSION LOGIC ######
22302235

22312236

2237+
# tree of lazy tensors
2238+
class LazyTorchTensor:
2239+
_meta: Tensor
2240+
_data: Tensor | None
2241+
_args: list[Any]
2242+
_func: Callable[[list[Any]], Tensor] | None = None
2243+
2244+
def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: list[Any] | None = None, func: Callable[[list[Any]], Tensor] | None = None):
2245+
self._meta = meta
2246+
self._data = data
2247+
self._args = args if args is not None else []
2248+
self._func = func
2249+
2250+
@staticmethod
2251+
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
2252+
# TODO: dicts
2253+
if isinstance(o, (list, tuple)):
2254+
l = []
2255+
for item in o:
2256+
l.append(LazyTorchTensor._recurse_apply(item, fn))
2257+
if isinstance(o, tuple):
2258+
l = tuple(l)
2259+
return l
2260+
elif isinstance(o, LazyTorchTensor):
2261+
return fn(o)
2262+
else:
2263+
return o
2264+
2265+
def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
2266+
def wrapped_fn(*args, **kwargs):
2267+
if kwargs is None:
2268+
kwargs = {}
2269+
args_list = ([self] if use_self else []) + list(args)
2270+
2271+
meta_args = LazyTorchTensor._recurse_apply(args_list, lambda t: t._meta)
2272+
2273+
return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args_list, func=lambda a: fn(*a, **kwargs))
2274+
return wrapped_fn
2275+
2276+
def __getattr__(self, __name: str) -> Any:
2277+
meta_attr = getattr(self._meta, __name)
2278+
if not callable(meta_attr):
2279+
return meta_attr
2280+
else:
2281+
return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
2282+
2283+
_dtype_map: dict[torch.dtype, type] = {
2284+
torch.float16: np.float16,
2285+
torch.float32: np.float32,
2286+
}
2287+
2288+
def numpy(self) -> gguf.LazyTensor:
2289+
dtype = self._dtype_map[self.dtype]
2290+
return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
2291+
2292+
@overload
2293+
@staticmethod
2294+
def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
2295+
2296+
@overload
2297+
@staticmethod
2298+
def to_eager(t: list[Tensor | LazyTorchTensor]) -> list[Tensor]: ...
2299+
2300+
@staticmethod
2301+
def to_eager(t: Any) -> Any:
2302+
def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
2303+
# wake up the lazy tensor
2304+
if _t._data is None and _t._func is not None:
2305+
# recurse into its arguments
2306+
_t._args = LazyTorchTensor.to_eager(_t._args)
2307+
_t._data = _t._func(_t._args)
2308+
if _t._data is not None:
2309+
return _t._data
2310+
else:
2311+
raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
2312+
2313+
# recurse into lists and/or tuples, keeping their structure
2314+
return LazyTorchTensor._recurse_apply(t, simple_to_eager)
2315+
2316+
@staticmethod
2317+
def from_eager(t: Tensor) -> Tensor:
2318+
if (t.__class__ == LazyTorchTensor):
2319+
return t
2320+
return LazyTorchTensor(meta=t.detach().to("meta"), data=t) # type: ignore
2321+
2322+
@classmethod
2323+
def __torch_function__(cls, func, types, args=(), kwargs=None):
2324+
del types # unused
2325+
2326+
if kwargs is None:
2327+
kwargs = {}
2328+
2329+
if func is torch.Tensor.numpy:
2330+
return args[0].numpy()
2331+
if func is torch.equal:
2332+
eager_args = LazyTorchTensor.to_eager(args)
2333+
return func(*eager_args, **kwargs)
2334+
2335+
return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
2336+
2337+
# special methods bypass __getattr__, so they need to be added manually
2338+
# ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
2339+
# NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
2340+
# as self._meta is currently used), because then the following
2341+
# operations would by default not be wrapped, and so not propagated
2342+
# when the tensor is made eager.
2343+
# It's better to get non-silent errors for not-yet-supported operators.
2344+
# TODO: add more when needed to avoid clutter, or find a more concise way
2345+
def __neg__(self, *args): # mamba
2346+
return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
2347+
2348+
def __add__(self, *args): # gemma
2349+
return self._wrap_fn(torch.Tensor.__add__)(self, *args)
2350+
2351+
def __getitem__(self, *args): # bloom falcon internlm2
2352+
return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
2353+
2354+
22322355
def parse_args() -> argparse.Namespace:
22332356
parser = argparse.ArgumentParser(
22342357
description="Convert a huggingface model to a GGML compatible file")
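To make the lazy mechanism above easier to follow, here is a minimal standalone sketch of the same idea (illustrative code, not part of the commit): the shape and dtype come from a copy of the tensor on PyTorch's "meta" device, while the real computation is kept as a closure and only executed on demand.

import torch

class TinyLazy:
    def __init__(self, meta, compute):
        self.meta = meta        # meta-device tensor: shape/dtype only, no data allocated
        self.compute = compute  # closure that produces the real tensor when asked

    @staticmethod
    def from_eager(t):
        return TinyLazy(t.detach().to("meta"), lambda: t)

    def permute(self, *dims):
        # record the operation on the meta tensor now, defer the real work
        return TinyLazy(self.meta.permute(*dims), lambda: self.compute().permute(*dims))

    def to_eager(self):
        return self.compute()

lazy = TinyLazy.from_eager(torch.ones(2, 3)).permute(1, 0)
print(lazy.meta.shape)        # torch.Size([3, 2]) -- known without computing anything
print(lazy.to_eager().shape)  # torch.Size([3, 2]) -- the permute actually runs here

LazyTorchTensor generalizes this by wrapping arbitrary torch.Tensor methods through __getattr__ and __torch_function__, and by recursing through argument lists when the result is finally made eager.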
@@ -2260,6 +2383,10 @@ def parse_args() -> argparse.Namespace:
         "--use-temp-file", action="store_true",
         help="use the tempfile library while processing (helpful when running out of memory, process killed)",
     )
+    parser.add_argument(
+        "--no-lazy", action="store_true",
+        help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
+    )
     parser.add_argument(
         "--model-name", type=str, default=None,
         help="name of the model",
@@ -2313,7 +2440,7 @@ def main() -> None:

     with torch.inference_mode():
         model_class = Model.from_model_architecture(hparams["architectures"][0])
-        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file)
+        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)

         logger.info("Set model parameters")
         model_instance.set_gguf_parameters()
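Usage note: with this change, lazy conversion becomes the default and the new flag restores the previous eager behaviour. A hedged example invocation (the model path below is a placeholder, not from this diff):

python convert-hf-to-gguf.py /path/to/hf-model             # lazy conversion (default)
python convert-hf-to-gguf.py /path/to/hf-model --no-lazy   # eager fallback, uses more RAM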

gguf-py/gguf/gguf_writer.py

Lines changed: 65 additions & 6 deletions
@@ -7,7 +7,7 @@
 import tempfile
 from enum import Enum, auto
 from io import BufferedWriter
-from typing import IO, Any, Sequence, Mapping
+from typing import IO, Any, Callable, Sequence, Mapping
 from string import ascii_letters, digits

 import numpy as np
@@ -28,6 +28,47 @@
 logger = logging.getLogger(__name__)


+class LazyTensor:
+    data: Callable[[], np.ndarray[Any, Any]]
+    # to avoid too deep recursion
+    functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
+    dtype: np.dtype[Any]
+    shape: tuple[int, ...]
+
+    def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
+        self.data = data
+        self.functions = []
+        self.dtype = np.dtype(dtype)
+        self.shape = shape
+
+    def astype(self, dtype: type, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.astype(dtype, **kwargs))
+        self.dtype = np.dtype(dtype)
+        return self
+
+    @property
+    def nbytes(self) -> int:
+        size = 1
+        for n in self.shape:
+            size *= n
+        return size * self.dtype.itemsize
+
+    def tofile(self, *args, **kwargs) -> None:
+        data = self.data()
+        for f in self.functions:
+            data = f(data)
+        assert data.shape == self.shape
+        assert data.dtype == self.dtype
+        assert data.nbytes == self.nbytes
+        self.functions = []
+        self.data = lambda: data
+        data.tofile(*args, **kwargs)
+
+    def byteswap(self, *args, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.byteswap(*args, **kwargs))
+        return self
+
+
 class WriterState(Enum):
     EMPTY = auto()
     HEADER = auto()
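As a hedged illustration of what LazyTensor buys (the values below are made up, not from the commit): the loader callback and any queued astype()/byteswap() transforms only run when tofile() finally needs the bytes, so nothing large is materialized while the tensor infos are being written.

import numpy as np

import gguf  # assumes the gguf-py package from this repo is importable

def slow_load() -> np.ndarray:
    print("materializing now")  # runs once, at write time, not at construction
    return np.ones((4, 4), dtype=np.float32)

# shape and dtype are declared up front, so nbytes is known without loading anything
lt = gguf.LazyTensor(slow_load, dtype=np.float32, shape=(4, 4)).astype(np.float16)
print(lt.nbytes)  # 32, computed from the declared shape and the new dtype

with open("tensor.bin", "wb") as f:
    lt.tofile(f)  # slow_load() and the queued astype() run only here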
@@ -38,7 +79,7 @@ class WriterState(Enum):
 class GGUFWriter:
     fout: BufferedWriter
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: list[np.ndarray[Any, Any]]
+    tensors: list[np.ndarray[Any, Any] | LazyTensor]
     _simple_value_packing = {
         GGUFValueType.UINT8: "B",
         GGUFValueType.INT8: "b",
@@ -237,7 +278,7 @@ def add_tensor_info(
         self.ti_data_count += 1

     def add_tensor(
-        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+        self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
         if self.endianess == GGUFEndian.BIG:
@@ -262,7 +303,7 @@ def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None
         if pad != 0:
             fp.write(bytes([0] * pad))

-    def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
+    def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
         if self.state is not WriterState.TI_DATA:
             raise ValueError(f'Expected output file to contain tensor info, got {self.state}')

@@ -272,15 +313,33 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
         tensor.tofile(self.fout)
         self.write_padding(self.fout, tensor.nbytes)

-    def write_tensors_to_file(self) -> None:
+    def write_tensors_to_file(self, *, progress: bool = False) -> None:
         self.write_ti_data_to_file()

         self.write_padding(self.fout, self.fout.tell())

         if self.temp_file is None:
+            self.tensors.reverse()  # to pop from the "beginning" in constant time
+
+            if progress:
+                from tqdm import tqdm
+
+                total_bytes = sum(t.nbytes for t in self.tensors)
+
+                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+
+                while True:
+                    try:
+                        tensor = self.tensors.pop()
+                    except IndexError:
+                        break
+                    tensor.tofile(self.fout)
+                    bar.update(tensor.nbytes)
+                    self.write_padding(self.fout, tensor.nbytes)
+                return
             while True:
                 try:
-                    tensor = self.tensors.pop(0)
+                    tensor = self.tensors.pop()
                 except IndexError:
                     break
                 tensor.tofile(self.fout)
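A short note on the loop change, with a trivial sketch of the pattern (illustrative values only): list.pop(0) shifts every remaining element and is O(n) per call, so the list is reversed once and then popped from the end in O(1); dropping each reference right after it is written also lets already-written tensors be freed.

# illustrative only: reverse once, then pop from the end in the original order
tensors = ["tok_embd", "blk.0.attn_q", "output"]  # made-up tensor names
tensors.reverse()
while True:
    try:
        t = tensors.pop()  # O(1); still yields items in the original order
    except IndexError:
        break
    print("writing", t)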
