Commit bc48b50

fix: Caching for hot-swapping LoRA adapters
1 parent a6a6b8c commit bc48b50

File tree

2 files changed: +31 -13 lines changed


docs/api-reference.md

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ High-level Python bindings for llama.cpp.
 - __call__
 - create_chat_completion
 - create_chat_completion_openai_v1
+- set_lora_adapter_scale
 - set_cache
 - save_state
 - load_state
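The new reference entry covers the renamed method. A minimal usage sketch, assuming a local GGUF model and an adapter file at placeholder paths:

    from llama_cpp import Llama

    # Placeholder paths, for illustration only
    llm = Llama(model_path="./model.gguf")

    # Load the adapter on first use and enable it at full strength
    llm.set_lora_adapter_scale("./style-adapter.gguf", 1.0, load_if_needed=True)

    # Later, disable it for inference without reloading the model
    llm.set_lora_adapter_scale("./style-adapter.gguf", 0.0)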

llama_cpp/llama.py

Lines changed: 30 additions & 13 deletions
@@ -18,6 +18,7 @@
     List,
     Literal,
     Optional,
+    Tuple,
     Union,
     Generator,
     Sequence,
@@ -352,7 +353,9 @@ def __init__(
 
         self.cache: Optional[BaseLlamaCache] = None
 
-        self.lora_adapters = lora_adapters
+        self.lora_adapters = (
+            lora_adapters if lora_adapters is not None else {}
+        )
 
         self.spm_infill = spm_infill
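With this change self.lora_adapters is always a dict mapping adapter path to scale, so the rest of __init__ can iterate it without a None check. A hedged constructor sketch (paths are placeholders):

    from llama_cpp import Llama

    # lora_adapters maps adapter path -> scale; a 0.0 scale loads the
    # adapter but leaves it inactive for inference
    llm = Llama(
        model_path="./model.gguf",
        lora_adapters={
            "./adapter-a.gguf": 1.0,
            "./adapter-b.gguf": 0.0,
        },
    )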

@@ -401,10 +404,13 @@ def __init__(
                 )
             )
 
-        self._lora_adapters_by_path: Dict[str, internals.LlamaLoraAdapter] = {}
-        if self.lora_adapters:
-            for lora_path, scale in self.lora_adapters.items():
-                self.set_lora_adapter(lora_path, scale, load_if_needed=True)
+        # Dict from LoRA path to wrapper
+        self._lora_adapters_paths: Dict[str, internals.LlamaLoraAdapter] = {}
+        # Immutable value representing active adapters for use as a key
+        self._lora_adapters_active: Tuple[Tuple[str, float], ...] = ()
+
+        for lora_path, scale in self.lora_adapters.copy().items():
+            self.set_lora_adapter_scale(lora_path, scale, load_if_needed=True)
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
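_lora_adapters_active is the state that later becomes part of the cache key: a sorted tuple of the (path, scale) pairs whose scale is non-zero, so dict insertion order and disabled adapters do not affect equality. A standalone sketch of that derivation:

    from typing import Dict, Tuple

    def active_adapters_key(lora_adapters: Dict[str, float]) -> Tuple[Tuple[str, float], ...]:
        """Immutable, order-independent view of the adapters that affect inference."""
        return tuple(sorted(
            (path, scale) for path, scale in lora_adapters.items() if scale != 0.0
        ))

    # Disabled adapters and insertion order do not change the key
    assert active_adapters_key({"b.gguf": 0.0, "a.gguf": 1.0}) == active_adapters_key({"a.gguf": 1.0})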
@@ -426,6 +432,7 @@ def __init__(
         self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab)
 
         self.n_tokens = 0
+        self.tokens_lora_adapters: Tuple[Tuple[str, float], ...] = ()  # Adapters that processed tokens
         self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
@@ -595,7 +602,7 @@ def set_seed(self, seed: int):
         """
         self._seed = seed
 
-    def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False):
+    def set_lora_adapter_scale(self, lora_path: str, scale: float, *, load_if_needed=False):
         """
         Set the scale for a LoRA adapter or 0.0 to disable it for inference. If the LoRA adapter file
         has previously been loaded then this method will set its scale. If the LoRA adapter file has
@@ -608,7 +615,8 @@ def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False
         method will attempt to load the adapter from the lora_path if needed. If False, loading an adapter that
         hasn't already been loaded will raise an exception.
         """
-        lora_adapter = self._lora_adapters_by_path.get(lora_path)
+        # Load adapter if needed (even if scale 0.0)
+        lora_adapter = self._lora_adapters_paths.get(lora_path)
         if lora_adapter is None:
             lora_adapter = internals.LlamaLoraAdapter(
                 self._model,
@@ -619,15 +627,24 @@ def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False
                 raise RuntimeError(
                     f"Failed to initialize LoRA adapter from lora path: {lora_path}"
                 )
-            self._lora_adapters_by_path[lora_path] = lora_adapter
+            self._lora_adapters_paths[lora_path] = lora_adapter
+
         if scale == 0.0:
-            self._ctx.lora_adapter_remove(lora_adapter)  # Safe even if not in context
+            # Remove from context; safe to call even if not in context
+            self._ctx.lora_adapter_remove(lora_adapter)
         else:
+            # Set scale in context
             self._ctx.lora_adapter_set(lora_adapter, scale)
 
+        self.lora_adapters[lora_path] = scale
+        self._lora_adapters_active = tuple(sorted(
+            filter(lambda path_scale: path_scale[1] != 0.0, self.lora_adapters.items())
+        ))
+
     def reset(self):
         """Reset the model state."""
         self.n_tokens = 0
+        self.tokens_lora_adapters = self._lora_adapters_active
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -878,7 +895,7 @@ def generate(
         )
 
         # Check for kv cache prefix match
-        if reset and self.n_tokens > 0:
+        if reset and self.n_tokens > 0 and self.tokens_lora_adapters == self._lora_adapters_active:
             longest_prefix = 0
             for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
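The extra condition keeps a KV-cache prefix computed under one adapter set from being reused under another: the adapters recorded at the last reset() must match the adapters active now, otherwise the prompt is evaluated from scratch. In practice this is what makes hot-swapping safe on a single Llama instance, as in this hedged sketch continuing the llm object from the earlier examples (adapter paths are placeholders):

    # Two "personalities" served from one loaded base model
    llm.set_lora_adapter_scale("./adapter-a.gguf", 1.0, load_if_needed=True)
    out_a = llm("Write a haiku about caching.", max_tokens=64)

    # Swap adapters; the next call does not reuse A's KV prefix and re-evaluates the prompt
    llm.set_lora_adapter_scale("./adapter-a.gguf", 0.0)
    llm.set_lora_adapter_scale("./adapter-b.gguf", 1.0, load_if_needed=True)
    out_b = llm("Write a haiku about caching.", max_tokens=64)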
@@ -1295,7 +1312,7 @@ def logit_bias_processor(
 
         if self.cache:
             try:
-                cache_item = self.cache[prompt_tokens]
+                cache_item = self.cache[(self._lora_adapters_active, prompt_tokens)]
                 cache_prefix_len = Llama.longest_token_prefix(
                     cache_item.input_ids.tolist(), prompt_tokens
                 )
@@ -1633,15 +1650,15 @@ def logit_bias_processor(
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
-                self.cache[prompt_tokens + completion_tokens] = self.save_state()
+                self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()
                 if self.verbose:
                     print("Llama._create_completion: cache saved", file=sys.stderr)
             return
 
         if self.cache:
             if self.verbose:
                 print("Llama._create_completion: cache save", file=sys.stderr)
-            self.cache[prompt_tokens + completion_tokens] = self.save_state()
+            self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()
 
         text_str = text.decode("utf-8", errors="ignore")
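Keying the completion cache on the pair (active adapters, token sequence) prevents a state saved under one adapter set from being served for the same prompt under another. A toy illustration with a plain dict standing in for the real cache object:

    # Hypothetical stand-in for the cache, keyed the same way as above
    cache: dict = {}
    prompt_tokens = (1, 15043, 3186)

    adapters_a = (("./adapter-a.gguf", 1.0),)
    adapters_b = (("./adapter-b.gguf", 1.0),)

    cache[(adapters_a, prompt_tokens)] = "state computed with adapter A"
    cache[(adapters_b, prompt_tokens)] = "state computed with adapter B"

    # Same prompt, different adapters -> different entries, no collision
    assert cache[(adapters_a, prompt_tokens)] != cache[(adapters_b, prompt_tokens)]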
16471664
