Commit 30887d2

feat: Update multi LoRA support in high-level Llama wrapper
1 parent: 8f43162

File tree

1 file changed

llama_cpp/llama.py

Lines changed: 39 additions & 38 deletions
@@ -96,9 +96,7 @@ def __init__(
         # Sampling Params
         last_n_tokens_size: int = 64,
         # LoRA Params
-        lora_base: Optional[str] = None,
-        lora_scale: float = 1.0,
-        lora_path: Optional[str] = None,
+        lora_adapters: Optional[Dict[str, float]] = None,
         # Backend Params
         numa: Union[bool, int] = False,
         # Chat Format Params
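
The new lora_adapters parameter attaches several adapters at once by mapping each adapter file path to its scale. A minimal constructor sketch under the new API (the model and adapter paths here are hypothetical):

    from llama_cpp import Llama

    # Hypothetical paths; lora_adapters maps each LoRA file to the scale it is applied at.
    llm = Llama(
        model_path="./models/base-model.gguf",
        lora_adapters={
            "./loras/style.gguf": 1.0,   # applied at full strength
            "./loras/domain.gguf": 0.5,  # applied at half strength
        },
    )
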
@@ -174,8 +172,7 @@ def __init__(
             offload_kqv: Offload K, Q, V to GPU.
             flash_attn: Use flash attention.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
-            lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
-            lora_path: Path to a LoRA file to apply to the model.
+            lora_adapters: Mapping of LoRA adapter file paths to the scale each is applied at (an adapter with a scale of 0.0 will not be used during inference).
             numa: numa policy
             chat_format: String specifying the chat format to use when calling create_chat_completion.
             chat_handler: Optional chat handler to use when calling create_chat_completion.
@@ -243,7 +240,7 @@ def __init__(
         )  # keep a reference to the array so it is not gc'd
         self.model_params.tensor_split = self._c_tensor_split
         self.model_params.vocab_only = vocab_only
-        self.model_params.use_mmap = use_mmap if lora_path is None else False
+        self.model_params.use_mmap = use_mmap
         self.model_params.use_mlock = use_mlock

         # kv_overrides is the original python dict
@@ -355,9 +352,7 @@ def __init__(

         self.cache: Optional[BaseLlamaCache] = None

-        self.lora_base = lora_base
-        self.lora_scale = lora_scale
-        self.lora_path = lora_path
+        self.lora_adapters = lora_adapters

         self.spm_infill = spm_infill

@@ -406,32 +401,10 @@ def __init__(
             )
         )

-        self._lora_adapter: Optional[llama_cpp.llama_lora_adapter_p] = None
-
-        if self.lora_path:
-            self._lora_adapter = llama_cpp.llama_lora_adapter_init(
-                self._model.model,
-                self.lora_path.encode("utf-8"),
-            )
-            if self._lora_adapter is None:
-                raise RuntimeError(
-                    f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
-                )
-
-            def free_lora_adapter():
-                if self._lora_adapter is None:
-                    return
-                llama_cpp.llama_lora_adapter_free(self._lora_adapter)
-                self._lora_adapter = None
-
-            self._stack.callback(free_lora_adapter)
-
-            if llama_cpp.llama_lora_adapter_set(
-                self._ctx.ctx, self._lora_adapter, self.lora_scale
-            ):
-                raise RuntimeError(
-                    f"Failed to set LoRA adapter from lora path: {self.lora_path}"
-                )
+        self._lora_adapters_by_path: Dict[str, internals.LlamaLoraAdapter] = {}
+        if self.lora_adapters:
+            for lora_path, scale in self.lora_adapters.items():
+                self.set_lora_adapter(lora_path, scale, load_if_needed=True)

         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
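
For callers migrating off the removed single-adapter parameters, the old lora_path/lora_scale pair maps onto a one-entry dict; lora_base is removed outright. A hedged before/after sketch (paths hypothetical):

    # Before this commit (removed API):
    # llm = Llama(model_path="./models/base.gguf", lora_path="./loras/a.gguf", lora_scale=0.8)

    # After this commit:
    llm = Llama(
        model_path="./models/base.gguf",
        lora_adapters={"./loras/a.gguf": 0.8},
    )
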
@@ -621,6 +594,40 @@ def set_seed(self, seed: int):
             seed: The random seed.
         """
         self._seed = seed
+
+    def set_lora_adapter(self, lora_path: str, scale: float, *, load_if_needed: bool = False):
+        """
+        Set the scale for a LoRA adapter, or 0.0 to disable it for inference. If the LoRA
+        adapter file has previously been loaded, this method sets its scale. If it has not,
+        this method raises an exception unless load_if_needed is set.
+
+        Args:
+            lora_path: The path to the LoRA adapter. Unless load_if_needed is True, this path must already have been loaded (for example when the `Llama` object was created).
+            scale: The scaling factor to apply to the LoRA adapter. If 0.0, the adapter is disabled and will not be used during inference.
+            load_if_needed: Whether to load the adapter if it has not previously been loaded. If True, this
+                method will attempt to load the adapter from lora_path if needed. If False, setting an adapter
+                that hasn't already been loaded will raise an exception.
+        """
+        lora_adapter = self._lora_adapters_by_path.get(lora_path)
+        if lora_adapter is None:
+            if not load_if_needed:
+                raise RuntimeError(
+                    f"LoRA adapter has not been loaded and load_if_needed is False: {lora_path}"
+                )
+            lora_adapter = internals.LlamaLoraAdapter(
+                self._model,
+                lora_path,
+                verbose=self.verbose,
+            )
+            if lora_adapter is None:
+                raise RuntimeError(
+                    f"Failed to initialize LoRA adapter from lora path: {lora_path}"
+                )
+            self._lora_adapters_by_path[lora_path] = lora_adapter
+        if scale == 0.0:
+            self._ctx.lora_adapter_remove(lora_adapter)  # Safe even if not in context
+        else:
+            self._ctx.lora_adapter_set(lora_adapter, scale)

     def reset(self):
         """Reset the model state."""
@@ -2096,9 +2099,7 @@ def __getstate__(self):
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
-            lora_base=self.lora_base,
-            lora_scale=self.lora_scale,
-            lora_path=self.lora_path,
+            lora_adapters=self.lora_adapters,
             # Backend Params
             numa=self.numa,
             # Chat Format Params
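
Because __getstate__ now captures lora_adapters in place of the three removed fields, a pickle round-trip should recreate the same adapters at their construction-time scales; note that set_lora_adapter does not write back into self.lora_adapters, so scales changed after construction are not carried across. A hedged sketch, reusing the hypothetical llm from above:

    import pickle

    # Round-trip the wrapper; the restored copy reloads its adapters from the
    # lora_adapters dict captured by __getstate__.
    llm2 = pickle.loads(pickle.dumps(llm))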
