
Commit 244539a

feat: Add multi LoRA support to internal model
1 parent e712cff

1 file changed: +54 −0

llama_cpp/_internals.py

Lines changed: 54 additions & 0 deletions
```diff
@@ -285,6 +285,18 @@ def kv_cache_seq_keep(self, seq_id: int):
     def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
         llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)
 
+    def lora_adapter_set(self, adapter: LlamaLoraAdapter, scale: float):
+        return_code = llama_cpp.llama_lora_adapter_set(self.ctx, adapter.lora_adapter, scale)
+        if return_code != 0:
+            raise RuntimeError(f"lora_adapter_set returned {return_code}")
+
+    def lora_adapter_remove(self, adapter: LlamaLoraAdapter) -> bool:
+        return_code = llama_cpp.llama_lora_adapter_remove(self.ctx, adapter.lora_adapter)
+        return return_code != 0
+
+    def lora_adapter_clear(self):
+        llama_cpp.llama_lora_adapter_clear(self.ctx)
+
     def get_state_size(self) -> int:
         return llama_cpp.llama_get_state_size(self.ctx)
```

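Taken together, these three methods are what the commit title means by "multi LoRA": several adapters can be attached to one context at the same time, each with its own scale. A hypothetical usage sketch follows; all names in it are illustrative, with `ctx` standing for a `LlamaContext` and the adapters for `LlamaLoraAdapter` instances as defined in the second hunk below.

```python
# Illustrative sketch only -- `ctx` is assumed to be a LlamaContext from
# llama_cpp._internals, and the adapters to be LlamaLoraAdapter instances
# (defined in the second hunk below). None of these names come from the diff.

# Attach two adapters to the same context, each with its own scale.
ctx.lora_adapter_set(style_adapter, 1.0)
ctx.lora_adapter_set(domain_adapter, 0.5)

# lora_adapter_set raises RuntimeError on a nonzero llama.cpp return code,
# so a failed attach surfaces immediately.

# Detach a single adapter; the bool mirrors the raw llama.cpp return code
# (nonzero maps to True).
ctx.lora_adapter_remove(style_adapter)

# Detach everything still attached to this context.
ctx.lora_adapter_clear()
```
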
```diff
@@ -861,3 +873,45 @@ def close(self):
 
     def __del__(self):
         self.close()
+
+class LlamaLoraAdapter:
+    """Intermediate Python wrapper for a llama.cpp llama_lora_adapter.
+    NOTE: For stability it's recommended you use the Llama class instead."""
+
+    def __init__(
+        self,
+        model: LlamaModel,
+        lora_path: str,
+        *,
+        verbose: bool = True,
+    ):
+        self.model = model
+        self.lora_path = lora_path
+
+        lora_adapter = None
+
+        if not os.path.exists(lora_path):
+            raise ValueError(f"LoRA adapter path does not exist: {lora_path}")
+
+        with suppress_stdout_stderr(disable=verbose):
+            lora_adapter = llama_cpp.llama_lora_adapter_init(
+                self.model.model,
+                self.lora_path.encode("utf-8"),
+            )
+
+        if lora_adapter is None:
+            raise RuntimeError(
+                f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
+            )
+
+        # The llama_lora_adapter will be freed by the llama_model as part of its
+        # lifecycle. The llama_model destructor destroys each llama_lora_adapter,
+        # and the destructor for llama_lora_adapter calls llama_lora_adapter_free.
+        # All we do here is clear the wrapped reference when the LlamaModel
+        # wrapper is closed, so that the LlamaLoraAdapter wrapper's reference is
+        # also cleared once the llama_lora_adapters have been freed.
+        def clear_lora_adapter():
+            self.lora_adapter = None
+        self.model._exit_stack.callback(clear_lora_adapter)
+
+        self.lora_adapter = lora_adapter
```
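
Constructing the wrapper itself might look like the following; the `.gguf` path is a placeholder and `model` is assumed to be an already-loaded `LlamaModel` from the same module.

```python
# Hypothetical construction sketch. "adapters/my-lora.gguf" is a placeholder
# path and `model` is assumed to be an existing llama_cpp._internals.LlamaModel.
adapter = LlamaLoraAdapter(
    model,
    "adapters/my-lora.gguf",  # ValueError if the file does not exist
    verbose=False,            # silence llama.cpp's native stdout/stderr output
)

# The adapter can now be attached to a context built from `model`, e.g.:
# ctx.lora_adapter_set(adapter, 1.0)
```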

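The lifecycle comment relies on `LlamaModel._exit_stack` being a `contextlib.ExitStack` whose callbacks run when the model wrapper is closed; that is an assumption about the surrounding class, since it is not shown in this diff. A self-contained sketch of the pattern, independent of llama.cpp:

```python
import contextlib

# Minimal illustration of the ExitStack pattern described in the diff's
# comment: callbacks registered on the stack run when the stack closes,
# which is how the adapter wrapper's stale reference gets cleared when
# the owning model is closed and frees the native adapters.

class Owner:
    """Stands in for LlamaModel: owns an exit stack that runs on close()."""

    def __init__(self) -> None:
        self._exit_stack = contextlib.ExitStack()

    def close(self) -> None:
        self._exit_stack.close()  # runs every registered callback (LIFO)


class Wrapper:
    """Stands in for LlamaLoraAdapter: clears its handle via a callback."""

    def __init__(self, owner: Owner) -> None:
        self.handle = object()  # stands in for the native lora_adapter

        def clear_handle() -> None:
            self.handle = None

        owner._exit_stack.callback(clear_handle)


owner = Owner()
wrapper = Wrapper(owner)
owner.close()
assert wrapper.handle is None  # reference cleared alongside owner teardown
```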