     List,
     Literal,
     Optional,
+    Tuple,
     Union,
     Generator,
     Sequence,
@@ -352,7 +353,9 @@ def __init__(

         self.cache: Optional[BaseLlamaCache] = None

-        self.lora_adapters = lora_adapters
+        self.lora_adapters = (
+            lora_adapters if lora_adapters is not None else {}
+        )

         self.spm_infill = spm_infill

@@ -401,10 +404,13 @@ def __init__(
             )
         )

-        self._lora_adapters_by_path: Dict[str, internals.LlamaLoraAdapter] = {}
-        if self.lora_adapters:
-            for lora_path, scale in self.lora_adapters.items():
-                self.set_lora_adapter(lora_path, scale, load_if_needed=True)
+        # Dict from LoRA path to wrapper
+        self._lora_adapters_paths: Dict[str, internals.LlamaLoraAdapter] = {}
+        # Immutable value representing active adapters for use as a key
+        self._lora_adapters_active: Tuple[Tuple[str, float]] = ()
+
+        for lora_path, scale in self.lora_adapters.copy().items():
+            self.set_lora_adapter_scale(lora_path, scale, load_if_needed=True)

         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
@@ -426,6 +432,7 @@ def __init__(
         self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab)

         self.n_tokens = 0
+        self.tokens_lora_adapters: Tuple[Tuple[str, float]] = ()  # Adapters that processed tokens
         self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
@@ -595,7 +602,7 @@ def set_seed(self, seed: int):
         """
         self._seed = seed

-    def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False):
+    def set_lora_adapter_scale(self, lora_path: str, scale: float, *, load_if_needed=False):
         """
         Set the scale for a LoRA adapter or 0.0 to disable it for inference. If the LoRA adapter file
         has previously been loaded then this method will set its scale. If the LoRA adapter file has
@@ -608,7 +615,8 @@ def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False
         method will attempt to load the adapter from the lora_path if needed. If False, loading an adapter that
         hasn't already been loaded will raise an exception.
         """
-        lora_adapter = self._lora_adapters_by_path.get(lora_path)
+        # Load adapter if needed (even if scale 0.0)
+        lora_adapter = self._lora_adapters_paths.get(lora_path)
         if lora_adapter is None:
             lora_adapter = internals.LlamaLoraAdapter(
                 self._model,
@@ -619,15 +627,24 @@ def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False
                 raise RuntimeError(
                     f"Failed to initialize LoRA adapter from lora path: {lora_path}"
                 )
-            self._lora_adapters_by_path[lora_path] = lora_adapter
+            self._lora_adapters_paths[lora_path] = lora_adapter
+
         if scale == 0.0:
-            self._ctx.lora_adapter_remove(lora_adapter)  # Safe even if not in context
+            # Remove from context; safe to call even if not in context
+            self._ctx.lora_adapter_remove(lora_adapter)
         else:
+            # Set scale in context
             self._ctx.lora_adapter_set(lora_adapter, scale)

+        self.lora_adapters[lora_path] = scale
+        self._lora_adapters_active = tuple(sorted(
+            filter(lambda path_scale: path_scale[1] != 0.0, self.lora_adapters.items())
+        ))
+
     def reset(self):
         """Reset the model state."""
         self.n_tokens = 0
+        self.tokens_lora_adapters = self._lora_adapters_active

     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -879,7 +896,7 @@ def generate(
         )

         # Check for kv cache prefix match
-        if reset and self.n_tokens > 0:
+        if reset and self.n_tokens > 0 and self.tokens_lora_adapters == self._lora_adapters_active:
             longest_prefix = 0
             for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
@@ -1296,7 +1313,7 @@ def logit_bias_processor(

         if self.cache:
             try:
-                cache_item = self.cache[prompt_tokens]
+                cache_item = self.cache[(self._lora_adapters_active, prompt_tokens)]
                 cache_prefix_len = Llama.longest_token_prefix(
                     cache_item.input_ids.tolist(), prompt_tokens
                 )
@@ -1634,15 +1651,15 @@ def logit_bias_processor(
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
-                self.cache[prompt_tokens + completion_tokens] = self.save_state()
+                self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()
                 if self.verbose:
                     print("Llama._create_completion: cache saved", file=sys.stderr)
             return

         if self.cache:
             if self.verbose:
                 print("Llama._create_completion: cache save", file=sys.stderr)
-            self.cache[prompt_tokens + completion_tokens] = self.save_state()
+            self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()

         text_str = text.decode("utf-8", errors="ignore")
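
For reference, here is a minimal, hypothetical usage sketch of the hot-swap API introduced by this diff. The model and adapter paths, scales, and prompt are placeholders, and attaching a `LlamaRAMCache` via `set_cache()` is an assumption to illustrate that cached state is keyed on both the active adapters and the tokens.

```python
# Hypothetical sketch; paths, scales, and prompt are illustrative only.
import llama_cpp

llm = llama_cpp.Llama(
    model_path="./models/base-model.gguf",        # placeholder model path
    lora_adapters={"./loras/style-a.gguf": 1.0},  # loaded at init via set_lora_adapter_scale
)

# Cached completion state is keyed on (active adapters, tokens), so state saved
# with one adapter set is not reused after the adapters change.
llm.set_cache(llama_cpp.LlamaRAMCache())

out_a = llm("Write a haiku about rivers.", max_tokens=32)

# Hot-swap without reloading the base model: disable adapter A, enable adapter B.
llm.set_lora_adapter_scale("./loras/style-a.gguf", 0.0)
llm.set_lora_adapter_scale("./loras/style-b.gguf", 0.5, load_if_needed=True)

out_b = llm("Write a haiku about rivers.", max_tokens=32)
```

Because `_lora_adapters_active` is a sorted tuple of non-zero (path, scale) pairs, it is hashable and order-independent, which is what allows it to be combined with the token sequence as a cache key and compared against `tokens_lora_adapters` before reusing the KV cache prefix.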