
Commit 6214b15

horenbergerb authored and vowelparrot committed
1 parent a21fc19

add LoRA loading for the LlamaCpp LLM (#3363)

First PR, let me know if this needs anything like unit tests, reformatting, etc. It seemed pretty straightforward to implement. The only hitch was that mmap needs to be disabled when loading LoRAs, or else you segfault.
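
For illustration, a minimal sketch of what the new options could look like from the LangChain side once this lands. The file paths are placeholders, and `use_mmap=False` reflects the hitch noted above: mmap has to be off when a LoRA is loaded.

```python
from langchain.llms import LlamaCpp

# Placeholder paths; substitute your own ggml model and LoRA files.
llm = LlamaCpp(
    model_path="./models/ggml-model-q4_0.bin",
    lora_base="./models/ggml-model-f16.bin",  # LoRA base model (new field)
    lora_path="./loras/my-lora.bin",          # the LoRA to apply (new field)
    use_mmap=False,  # per the commit message: mmap must be off when loading a LoRA
)

print(llm("Question: what does a LoRA adapter change? Answer:"))
```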

File tree: 1 file changed (+15 lines, -0 lines)

langchain/llms/llamacpp.py: 15 additions, 0 deletions
```diff
@@ -27,6 +27,12 @@ class LlamaCpp(LLM):
     model_path: str
     """The path to the Llama model file."""
 
+    lora_base: Optional[str] = None
+    """The path to the Llama LoRA base model."""
+
+    lora_path: Optional[str] = None
+    """The path to the Llama LoRA. If None, no LoRA is loaded."""
+
     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
 
@@ -87,13 +93,18 @@ class LlamaCpp(LLM):
     last_n_tokens_size: Optional[int] = 64
     """The number of tokens to look back when applying the repeat_penalty."""
 
+    use_mmap: Optional[bool] = True
+    """Whether to memory-map the model file. Must be False when a LoRA is loaded."""
+
     streaming: bool = True
     """Whether to stream the results, token by token."""
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
         model_path = values["model_path"]
+        lora_path = values["lora_path"]
+        lora_base = values["lora_base"]
         n_ctx = values["n_ctx"]
         n_parts = values["n_parts"]
         seed = values["seed"]
@@ -103,13 +114,16 @@ def validate_environment(cls, values: Dict) -> Dict:
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
         n_batch = values["n_batch"]
+        use_mmap = values["use_mmap"]
         last_n_tokens_size = values["last_n_tokens_size"]
 
         try:
             from llama_cpp import Llama
 
             values["client"] = Llama(
                 model_path=model_path,
+                lora_base=lora_base,
+                lora_path=lora_path,
                 n_ctx=n_ctx,
                 n_parts=n_parts,
                 seed=seed,
@@ -119,6 +133,7 @@ def validate_environment(cls, values: Dict) -> Dict:
                 use_mlock=use_mlock,
                 n_threads=n_threads,
                 n_batch=n_batch,
+                use_mmap=use_mmap,
                 last_n_tokens_size=last_n_tokens_size,
             )
         except ImportError:
```
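
The patch passes the new values straight through to `llama_cpp.Llama` but does not itself enforce the mmap constraint from the commit message. A hypothetical guard, sketched below and not part of this commit, could turn the segfault into an early error inside `validate_environment`:

```python
from typing import Dict


def check_lora_mmap(values: Dict) -> Dict:
    """Hypothetical helper (not in this commit): fail fast instead of
    letting llama.cpp segfault when a LoRA is combined with mmap."""
    if values.get("lora_path") is not None and values.get("use_mmap", True):
        raise ValueError(
            "use_mmap must be False when lora_path is set; "
            "loading a LoRA with mmap enabled segfaults in llama.cpp."
        )
    return values
```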

0 commit comments