@@ -18,6 +18,7 @@
     List,
     Literal,
     Optional,
+    Tuple,
     Union,
     Generator,
     Sequence,
@@ -352,7 +353,9 @@ def __init__(
 
         self.cache: Optional[BaseLlamaCache] = None
 
-        self.lora_adapters = lora_adapters
+        self.lora_adapters = (
+            lora_adapters if lora_adapters is not None else {}
+        )
 
         self.spm_infill = spm_infill
 
@@ -401,10 +404,13 @@ def __init__(
             )
         )
 
-        self._lora_adapters_by_path: Dict[str, internals.LlamaLoraAdapter] = {}
-        if self.lora_adapters:
-            for lora_path, scale in self.lora_adapters.items():
-                self.set_lora_adapter(lora_path, scale, load_if_needed=True)
+        # Dict from LoRA path to wrapper
+        self._lora_adapters_paths: Dict[str, internals.LlamaLoraAdapter] = {}
+        # Immutable value representing active adapters for use as a key
+        self._lora_adapters_active: Tuple[Tuple[str, float], ...] = ()
+
+        for lora_path, scale in self.lora_adapters.copy().items():
+            self.set_lora_adapter_scale(lora_path, scale, load_if_needed=True)
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
@@ -426,6 +432,7 @@ def __init__(
         self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab)
 
         self.n_tokens = 0
+        self.tokens_lora_adapters: Tuple[Tuple[str, float], ...] = ()  # Adapters that processed tokens
         self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
@@ -595,7 +602,7 @@ def set_seed(self, seed: int):
         """
         self._seed = seed
 
-    def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False):
+    def set_lora_adapter_scale(self, lora_path: str, scale: float, *, load_if_needed=False):
         """
         Set the scale for a LoRA adapter or 0.0 to disable it for inference. If the LoRA adapter file
         has previously been loaded then this method will set its scale. If the LoRA adapter file has
@@ -608,7 +615,8 @@ def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False
         method will attempt to load the adapter from the lora_path if needed. If False, loading an adapter that
         hasn't already been loaded will raise an exception.
         """
-        lora_adapter = self._lora_adapters_by_path.get(lora_path)
+        # Load adapter if needed (even if scale 0.0)
+        lora_adapter = self._lora_adapters_paths.get(lora_path)
         if lora_adapter is None:
             lora_adapter = internals.LlamaLoraAdapter(
                 self._model,
@@ -619,15 +627,24 @@ def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False
                 raise RuntimeError(
                     f"Failed to initialize LoRA adapter from lora path: {lora_path}"
                 )
-            self._lora_adapters_by_path[lora_path] = lora_adapter
+            self._lora_adapters_paths[lora_path] = lora_adapter
+
         if scale == 0.0:
-            self._ctx.lora_adapter_remove(lora_adapter)  # Safe even if not in context
+            # Remove from context; safe to call even if not in context
+            self._ctx.lora_adapter_remove(lora_adapter)
         else:
+            # Set scale in context
             self._ctx.lora_adapter_set(lora_adapter, scale)
 
+        self.lora_adapters[lora_path] = scale
+        self._lora_adapters_active = tuple(sorted(
+            filter(lambda path_scale: path_scale[1] != 0.0, self.lora_adapters.items())
+        ))
+
     def reset(self):
         """Reset the model state."""
         self.n_tokens = 0
+        self.tokens_lora_adapters = self._lora_adapters_active
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -878,7 +895,7 @@ def generate(
         )
 
         # Check for kv cache prefix match
-        if reset and self.n_tokens > 0:
+        if reset and self.n_tokens > 0 and self.tokens_lora_adapters == self._lora_adapters_active:
             longest_prefix = 0
             for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
@@ -1295,7 +1312,7 @@ def logit_bias_processor(
 
         if self.cache:
             try:
-                cache_item = self.cache[prompt_tokens]
+                cache_item = self.cache[(self._lora_adapters_active, prompt_tokens)]
                 cache_prefix_len = Llama.longest_token_prefix(
                     cache_item.input_ids.tolist(), prompt_tokens
                 )
@@ -1633,15 +1650,15 @@ def logit_bias_processor(
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
-                self.cache[prompt_tokens + completion_tokens] = self.save_state()
+                self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()
                 if self.verbose:
                     print("Llama._create_completion: cache saved", file=sys.stderr)
             return
 
         if self.cache:
             if self.verbose:
                 print("Llama._create_completion: cache save", file=sys.stderr)
-            self.cache[prompt_tokens + completion_tokens] = self.save_state()
+            self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()
 
         text_str = text.decode("utf-8", errors="ignore")
 
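For context, a minimal usage sketch of the behaviour this diff introduces, assuming the existing llama_cpp API for completion and caching (LlamaRAMCache, set_cache, create_completion) plus the lora_adapters constructor argument and set_lora_adapter_scale shown above; the model and adapter paths are placeholders:

    # Hypothetical usage sketch; model/adapter paths are placeholders.
    from llama_cpp import Llama
    from llama_cpp.llama_cache import LlamaRAMCache

    llm = Llama(
        model_path="./models/model.gguf",           # placeholder
        lora_adapters={"./loras/style.gguf": 1.0},  # loaded at init via set_lora_adapter_scale()
    )
    llm.set_cache(LlamaRAMCache())

    # Completion runs with the adapter active; the cache entry is keyed on
    # (_lora_adapters_active, tokens), so it is not reused under a different
    # adapter configuration.
    print(llm.create_completion("Hello", max_tokens=16)["choices"][0]["text"])

    # Hot-swap: scale 0.0 removes the adapter from the context. Because
    # _lora_adapters_active changes, generate() also skips the KV-cache
    # prefix match on the next call.
    llm.set_lora_adapter_scale("./loras/style.gguf", 0.0)
    print(llm.create_completion("Hello", max_tokens=16)["choices"][0]["text"])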