 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
 
 import numpy as np
 import torch
@@ -63,7 +63,7 @@ class Model:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
+    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
         if self.__class__ == Model:
             raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
@@ -81,6 +81,9 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensors = dict(self.get_tensors())
+        if not eager:
+            for k, v in self.tensors.items():
+                self.tensors[k] = LazyTorchTensor.from_eager(v)
 
     @classmethod
     def __init_subclass__(cls):
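
Not part of the diff, but for context: the wrapping above relies on PyTorch's "meta" device, which tracks only shapes and dtypes and allocates no storage, so building a lazy graph over every model tensor costs almost nothing. A minimal standalone sketch of that behaviour:

import torch

# a meta tensor records shape/dtype but holds no data
m = torch.empty((4096, 4096), dtype=torch.float16, device="meta")
print(m.shape, m.dtype)  # torch.Size([4096, 4096]) torch.float16
print((m + m).shape)     # ops still propagate shapes without computing anything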
@@ -245,9 +248,11 @@ def write_tensors(self):
 
     def write(self):
         self.write_tensors()
+        self.tensors.clear()  # save memory by not keeping references to the tensors
+
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_tensors_to_file()
+        self.gguf_writer.write_tensors_to_file(progress=True)
         self.gguf_writer.close()
 
     def write_vocab(self):
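
The `self.tensors.clear()` call drops the converter's last strong references once the tensor data has been handed to the writer, so each tensor can be reclaimed as soon as it is written. A generic illustration of the principle (not the PR's code):

import gc
import weakref

import torch

t = torch.zeros(8, 8)
alive = weakref.ref(t)
tensors = {"w": t}      # stand-in for Model.tensors
del t
tensors.clear()         # drop the last strong reference, as write() now does
gc.collect()
print(alive() is None)  # True: the tensor's storage can be reclaimed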
@@ -2229,6 +2234,124 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 ###### CONVERSION LOGIC ######
 
 
+# tree of lazy tensors
+class LazyTorchTensor:
+    _meta: Tensor
+    _data: Tensor | None
+    _args: list[Any]
+    _func: Callable[[list[Any]], Tensor] | None = None
+
+    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: list[Any] | None = None, func: Callable[[list[Any]], Tensor] | None = None):
+        self._meta = meta
+        self._data = data
+        self._args = args if args is not None else []
+        self._func = func
+
+    @staticmethod
+    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
+        # TODO: dicts
+        if isinstance(o, (list, tuple)):
+            l = []
+            for item in o:
+                l.append(LazyTorchTensor._recurse_apply(item, fn))
+            if isinstance(o, tuple):
+                l = tuple(l)
+            return l
+        elif isinstance(o, LazyTorchTensor):
+            return fn(o)
+        else:
+            return o
+
+    def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
+        def wrapped_fn(*args, **kwargs):
+            if kwargs is None:
+                kwargs = {}
+            args_list = ([self] if use_self else []) + list(args)
+
+            meta_args = LazyTorchTensor._recurse_apply(args_list, lambda t: t._meta)
+
+            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args_list, func=lambda a: fn(*a, **kwargs))
+        return wrapped_fn
+
+    def __getattr__(self, __name: str) -> Any:
+        meta_attr = getattr(self._meta, __name)
+        if not callable(meta_attr):
+            return meta_attr
+        else:
+            return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
+
+    _dtype_map: dict[torch.dtype, type] = {
+        torch.float16: np.float16,
+        torch.float32: np.float32,
+    }
+
+    def numpy(self) -> gguf.LazyTensor:
+        dtype = self._dtype_map[self.dtype]
+        return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
+
+    @overload
+    @staticmethod
+    def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
+
+    @overload
+    @staticmethod
+    def to_eager(t: list[Tensor | LazyTorchTensor]) -> list[Tensor]: ...
+
+    @staticmethod
+    def to_eager(t: Any) -> Any:
+        def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
+            # wake up the lazy tensor
+            if _t._data is None and _t._func is not None:
+                # recurse into its arguments
+                _t._args = LazyTorchTensor.to_eager(_t._args)
+                _t._data = _t._func(_t._args)
+            if _t._data is not None:
+                return _t._data
+            else:
+                raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
+
+        # recurse into lists and/or tuples, keeping their structure
+        return LazyTorchTensor._recurse_apply(t, simple_to_eager)
+
+    @staticmethod
+    def from_eager(t: Tensor) -> Tensor:
+        if (t.__class__ == LazyTorchTensor):
+            return t
+        return LazyTorchTensor(meta=t.detach().to("meta"), data=t)  # type: ignore
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        del types  # unused
+
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.Tensor.numpy:
+            return args[0].numpy()
+        if func is torch.equal:
+            eager_args = LazyTorchTensor.to_eager(args)
+            return func(*eager_args, **kwargs)
+
+        return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
+
+    # special methods bypass __getattr__, so they need to be added manually
+    # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
+    # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
+    #       as self._meta is currently used), because then the following
+    #       operations would by default not be wrapped, and so not propagated
+    #       when the tensor is made eager.
+    #       It's better to get non-silent errors for not-yet-supported operators.
+    # TODO: add more when needed to avoid clutter, or find a more concise way
+    def __neg__(self, *args):  # mamba
+        return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
+
+    def __add__(self, *args):  # gemma
+        return self._wrap_fn(torch.Tensor.__add__)(self, *args)
+
+    def __getitem__(self, *args):  # bloom falcon internlm2
+        return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
+
+
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
         description="Convert a huggingface model to a GGML compatible file")
@@ -2260,6 +2383,10 @@ def parse_args() -> argparse.Namespace:
         "--use-temp-file", action="store_true",
         help="use the tempfile library while processing (helpful when running out of memory, process killed)",
     )
+    parser.add_argument(
+        "--no-lazy", action="store_true",
+        help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
+    )
     parser.add_argument(
         "--model-name", type=str, default=None,
         help="name of the model",
@@ -2313,7 +2440,7 @@ def main() -> None:
 
     with torch.inference_mode():
         model_class = Model.from_model_architecture(hparams["architectures"][0])
-        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file)
+        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)
 
         logger.info("Set model parameters")
         model_instance.set_gguf_parameters()
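
Finally, a standalone sketch of the flag wiring (parser construction repeated here only for illustration): `--no-lazy` sets `args.no_lazy`, which is passed as the new `eager` argument, so giving the flag skips the LazyTorchTensor wrapping in `Model.__init__`:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no-lazy", action="store_true")

print(parser.parse_args([]).no_lazy)             # False -> eager=False, lazy conversion (default)
print(parser.parse_args(["--no-lazy"]).no_lazy)  # True  -> eager=True, compute everything up front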