
Commit 8e48f1c

fix: Caching for hot-swapping LoRA adapters
1 parent 30887d2 commit 8e48f1c

File tree

2 files changed: +31 −13 lines changed

docs/api-reference.md
llama_cpp/llama.py
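In short, this commit keys cached completion state by which LoRA adapters are currently active, so adapters can be hot-swapped between completions without the cache handing back state computed under a different adapter configuration. A rough usage sketch follows; the model and adapter paths, prompts, and scales are placeholders, and it assumes llama_cpp exposes Llama, LlamaRAMCache, set_cache(), and the set_lora_adapter_scale() method shown in the diff below.

# Hypothetical usage sketch; paths, prompts, and scales are placeholders, not from
# this commit. Assumes llama_cpp exposes Llama, LlamaRAMCache, set_cache(), and the
# set_lora_adapter_scale() method shown in the diff below.
from llama_cpp import Llama, LlamaRAMCache

llm = Llama(
    model_path="./base-model.gguf",                   # placeholder path
    lora_adapters={"./adapters/style-a.gguf": 1.0},   # start with one adapter active
)
llm.set_cache(LlamaRAMCache())  # completion state is now cached per adapter set

# First completion runs with style-a at scale 1.0 and populates the cache.
print(llm("Write a haiku about caching.", max_tokens=32)["choices"][0]["text"])

# Hot-swap adapters between completions, without reloading the base model.
llm.set_lora_adapter_scale("./adapters/style-a.gguf", 0.0)
llm.set_lora_adapter_scale("./adapters/style-b.gguf", 1.0, load_if_needed=True)

# Same prompt, different active adapters: the old cache entry is not reused.
print(llm("Write a haiku about caching.", max_tokens=32)["choices"][0]["text"])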

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ High-level Python bindings for llama.cpp.
 - __call__
 - create_chat_completion
 - create_chat_completion_openai_v1
+- set_lora_adapter_scale
 - set_cache
 - save_state
 - load_state

llama_cpp/llama.py

Lines changed: 30 additions & 13 deletions
@@ -18,6 +18,7 @@
     List,
     Literal,
     Optional,
+    Tuple,
     Union,
     Generator,
     Sequence,
@@ -352,7 +353,9 @@ def __init__(

         self.cache: Optional[BaseLlamaCache] = None

-        self.lora_adapters = lora_adapters
+        self.lora_adapters = (
+            lora_adapters if lora_adapters is not None else {}
+        )

         self.spm_infill = spm_infill

@@ -401,10 +404,13 @@ def __init__(
             )
         )

-        self._lora_adapters_by_path: Dict[str, internals.LlamaLoraAdapter] = {}
-        if self.lora_adapters:
-            for lora_path, scale in self.lora_adapters.items():
-                self.set_lora_adapter(lora_path, scale, load_if_needed=True)
+        # Dict from LoRA path to wrapper
+        self._lora_adapters_paths: Dict[str, internals.LlamaLoraAdapter] = {}
+        # Immutable value representing active adapters for use as a key
+        self._lora_adapters_active: Tuple[Tuple[str, float]] = ()
+
+        for lora_path, scale in self.lora_adapters.copy().items():
+            self.set_lora_adapter_scale(lora_path, scale, load_if_needed=True)

         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
@@ -426,6 +432,7 @@ def __init__(
         self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab)

         self.n_tokens = 0
+        self.tokens_lora_adapters: Tuple[Tuple[str, float]] = ()  # Adapters that processed tokens
         self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
@@ -595,7 +602,7 @@ def set_seed(self, seed: int):
         """
         self._seed = seed

-    def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed=False):
+    def set_lora_adapter_scale(self, lora_path: str, scale: float, *, load_if_needed=False):
         """
         Set the scale for a LoRA adapter or 0.0 to disable it for inference. If the LoRA adapter file
         has previously been loaded then this method will set its scale. If the LoRA adapter file has
@@ -608,7 +615,8 @@ def set_lora_adapter_scale(self, lora_path: str, scale: float, *, load_if_needed=False
                 method will attempt to load the adapter from the lora_path if needed. If False, loading an adapter that
                 hasn't already been loaded will raise an exception.
         """
-        lora_adapter = self._lora_adapters_by_path.get(lora_path)
+        # Load adapter if needed (even if scale 0.0)
+        lora_adapter = self._lora_adapters_paths.get(lora_path)
         if lora_adapter is None:
             lora_adapter = internals.LlamaLoraAdapter(
                 self._model,
@@ -619,15 +627,24 @@ def set_lora_adapter_scale(self, lora_path: str, scale: float, *, load_if_needed=False
                 raise RuntimeError(
                     f"Failed to initialize LoRA adapter from lora path: {lora_path}"
                 )
-            self._lora_adapters_by_path[lora_path] = lora_adapter
+            self._lora_adapters_paths[lora_path] = lora_adapter
+
         if scale == 0.0:
-            self._ctx.lora_adapter_remove(lora_adapter)  # Safe even if not in context
+            # Remove from context; safe to call even if not in context
+            self._ctx.lora_adapter_remove(lora_adapter)
         else:
+            # Set scale in context
             self._ctx.lora_adapter_set(lora_adapter, scale)

+        self.lora_adapters[lora_path] = scale
+        self._lora_adapters_active = tuple(sorted(
+            filter(lambda path_scale: path_scale[1] != 0.0, self.lora_adapters.items())
+        ))
+
     def reset(self):
         """Reset the model state."""
         self.n_tokens = 0
+        self.tokens_lora_adapters = self._lora_adapters_active

     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -879,7 +896,7 @@ def generate(
             )

         # Check for kv cache prefix match
-        if reset and self.n_tokens > 0:
+        if reset and self.n_tokens > 0 and self.tokens_lora_adapters == self._lora_adapters_active:
             longest_prefix = 0
             for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
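The guard added to generate() only attempts KV-cache prefix reuse when the adapters that evaluated the tokens already in context (`tokens_lora_adapters`, snapshotted in reset()) match the adapters active now. A toy restatement of that condition; the helper and sample values are hypothetical.

# Toy restatement of the guard above; the helper and sample values are hypothetical.
from typing import Tuple

AdapterKey = Tuple[Tuple[str, float], ...]

def can_reuse_kv_prefix(reset: bool, n_tokens: int,
                        tokens_lora_adapters: AdapterKey,
                        lora_adapters_active: AdapterKey) -> bool:
    # Mirrors the condition in generate(): only search for a shared token prefix when
    # the tokens already in context were evaluated under the currently active adapters.
    return reset and n_tokens > 0 and tokens_lora_adapters == lora_adapters_active

# Adapter was swapped between calls: fall through and re-evaluate from scratch.
assert not can_reuse_kv_prefix(True, 42, (("style-a.gguf", 1.0),), (("style-b.gguf", 1.0),))
# No swap: prefix matching proceeds as before this commit.
assert can_reuse_kv_prefix(True, 42, (("style-a.gguf", 1.0),), (("style-a.gguf", 1.0),))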
@@ -1296,7 +1313,7 @@ def logit_bias_processor(

         if self.cache:
             try:
-                cache_item = self.cache[prompt_tokens]
+                cache_item = self.cache[(self._lora_adapters_active, prompt_tokens)]
                 cache_prefix_len = Llama.longest_token_prefix(
                     cache_item.input_ids.tolist(), prompt_tokens
                 )
@@ -1634,15 +1651,15 @@ def logit_bias_processor(
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
-                self.cache[prompt_tokens + completion_tokens] = self.save_state()
+                self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()
                 if self.verbose:
                     print("Llama._create_completion: cache saved", file=sys.stderr)
                 return

         if self.cache:
             if self.verbose:
                 print("Llama._create_completion: cache save", file=sys.stderr)
-            self.cache[prompt_tokens + completion_tokens] = self.save_state()
+            self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()

         text_str = text.decode("utf-8", errors="ignore")

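With the two cache-save sites and the lookup above all keyed on `(self._lora_adapters_active, tokens)`, a token sequence cached under one adapter set can no longer satisfy a lookup made under another. A toy illustration follows; a plain dict stands in for the cache, while the real BaseLlamaCache implementations do longest-prefix matching rather than exact lookups, and the token ids and adapter paths are placeholders.

# Toy illustration; a plain dict stands in for the cache, while the real
# BaseLlamaCache implementations do longest-prefix matching rather than exact lookups.
toy_cache = {}

prompt_tokens = (1, 15043, 3186)            # placeholder token ids
adapters_a = (("style-a.gguf", 1.0),)       # key component while style-a is active
adapters_b = (("style-b.gguf", 1.0),)       # key component after hot-swapping

toy_cache[(adapters_a, prompt_tokens)] = "state computed with style-a"

# Identical prompt tokens, different active adapters: different key, no stale hit.
assert (adapters_b, prompt_tokens) not in toy_cache
assert toy_cache[(adapters_a, prompt_tokens)] == "state computed with style-a"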