@@ -96,9 +96,7 @@ def __init__(
         # Sampling Params
         last_n_tokens_size: int = 64,
         # LoRA Params
-        lora_base: Optional[str] = None,
-        lora_scale: float = 1.0,
-        lora_path: Optional[str] = None,
+        lora_adapters: Optional[Dict[str, float]] = None,
         # Backend Params
         numa: Union[bool, int] = False,
         # Chat Format Params
@@ -174,8 +172,7 @@ def __init__(
             offload_kqv: Offload K, Q, V to GPU.
             flash_attn: Use flash attention.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
-            lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
-            lora_path: Path to a LoRA file to apply to the model.
+            lora_adapters: Paths to LoRA adapter files and the scale to apply them at (a scale of 0.0 disables the adapter during inference).
             numa: numa policy
             chat_format: String specifying the chat format to use when calling create_chat_completion.
             chat_handler: Optional chat handler to use when calling create_chat_completion.
@@ -243,7 +240,7 @@ def __init__(
         )  # keep a reference to the array so it is not gc'd
         self.model_params.tensor_split = self._c_tensor_split
         self.model_params.vocab_only = vocab_only
-        self.model_params.use_mmap = use_mmap if lora_path is None else False
+        self.model_params.use_mmap = use_mmap
         self.model_params.use_mlock = use_mlock

         # kv_overrides is the original python dict
@@ -355,9 +352,7 @@ def __init__(

         self.cache: Optional[BaseLlamaCache] = None

-        self.lora_base = lora_base
-        self.lora_scale = lora_scale
-        self.lora_path = lora_path
+        self.lora_adapters = lora_adapters

         self.spm_infill = spm_infill

@@ -406,32 +401,10 @@ def __init__(
             )
         )

-        self._lora_adapter: Optional[llama_cpp.llama_lora_adapter_p] = None
-
-        if self.lora_path:
-            self._lora_adapter = llama_cpp.llama_lora_adapter_init(
-                self._model.model,
-                self.lora_path.encode("utf-8"),
-            )
-            if self._lora_adapter is None:
-                raise RuntimeError(
-                    f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
-                )
-
-            def free_lora_adapter():
-                if self._lora_adapter is None:
-                    return
-                llama_cpp.llama_lora_adapter_free(self._lora_adapter)
-                self._lora_adapter = None
-
-            self._stack.callback(free_lora_adapter)
-
-            if llama_cpp.llama_lora_adapter_set(
-                self._ctx.ctx, self._lora_adapter, self.lora_scale
-            ):
-                raise RuntimeError(
-                    f"Failed to set LoRA adapter from lora path: {self.lora_path}"
-                )
+        self._lora_adapters_by_path: Dict[str, internals.LlamaLoraAdapter] = {}
+        if self.lora_adapters:
+            for lora_path, scale in self.lora_adapters.items():
+                self.set_lora_adapter(lora_path, scale, load_if_needed=True)

         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
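With this change, adapters are passed to the constructor as a mapping from adapter path to scale and are loaded up front. A minimal usage sketch, with hypothetical model and adapter paths chosen for illustration:

```python
from llama_cpp import Llama

# Paths below are placeholders, not files shipped with the library.
llm = Llama(
    model_path="./models/base-model-q4.gguf",
    lora_adapters={
        "./adapters/style-lora.gguf": 1.0,   # applied at full strength
        "./adapters/domain-lora.gguf": 0.0,  # loaded now, but disabled for inference
    },
)
```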
@@ -621,6 +594,36 @@ def set_seed(self, seed: int):
             seed: The random seed.
         """
         self._seed = seed
+
+    def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False):
+        """
+        Set the scale for a LoRA adapter, or pass 0.0 to disable it for inference. If the LoRA adapter
+        file has previously been loaded, this method simply updates its scale. If it has not been
+        loaded yet, this method raises an exception unless load_if_needed is set.
+
+        Args:
+            lora_path: The path to the LoRA adapter. Unless load_if_needed is True, this path must have been loaded when the `Llama` object was created.
+            scale: The scaling factor to apply to the LoRA adapter. If 0.0, the LoRA adapter will be disabled so it won't be used during inference.
+            load_if_needed: Whether or not to load the adapter if it has not previously been loaded. If True, this
+                method will attempt to load the adapter from lora_path if needed. If False, using an adapter that
+                hasn't already been loaded will raise an exception.
+        """
+        lora_adapter = self._lora_adapters_by_path.get(lora_path)
+        if lora_adapter is None:
+            lora_adapter = internals.LlamaLoraAdapter(
+                self._model,
+                lora_path,
+                verbose=self.verbose,
+            )
+            if lora_adapter is None:
+                raise RuntimeError(
+                    f"Failed to initialize LoRA adapter from lora path: {lora_path}"
+                )
+            self._lora_adapters_by_path[lora_path] = lora_adapter
+        if scale == 0.0:
+            self._ctx.lora_adapter_remove(lora_adapter)  # Safe even if not in context
+        else:
+            self._ctx.lora_adapter_set(lora_adapter, scale)

     def reset(self):
         """Reset the model state."""
@@ -2096,9 +2099,7 @@ def __getstate__(self):
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
-            lora_base=self.lora_base,
-            lora_scale=self.lora_scale,
-            lora_path=self.lora_path,
+            lora_adapters=self.lora_adapters,
             # Backend Params
             numa=self.numa,
             # Chat Format Params