
Commit 2877e6e

feat: Multi-LoRA common args / low level API
1 parent 8e48f1c commit 2877e6e

File tree

2 files changed: +64 -34 lines changed


examples/low_level_api/common.py

Lines changed: 56 additions & 18 deletions
@@ -3,10 +3,11 @@
 import re
 
 from dataclasses import dataclass, field
-from typing import List
-
-# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
+from typing import List, Sequence, Tuple
+import typing
 
+# Based on https://github.com/ggerganov/llama.cpp/blob/master/common/common.cpp
+# and https://github.com/ggerganov/llama.cpp/blob/master/common/arg.cpp
 
 @dataclass
 class GptParams:
@@ -40,8 +41,8 @@ class GptParams:
     input_suffix: str = ""
     antiprompt: List[str] = field(default_factory=list)
 
-    lora_adapter: str = ""
-    lora_base: str = ""
+    lora: List[str] = None
+    lora_scaled: List[Tuple[str, float]] = None
 
     memory_f16: bool = True
    random_prompt: bool = False
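Note: the two new fields replace the old single-adapter pair lora_adapter/lora_base. Plain adapter paths accumulate in lora, and (path, scale) pairs in lora_scaled. A minimal sketch of how calling code might populate them directly (the adapter paths here are hypothetical):

from common import GptParams  # assuming examples/low_level_api is on sys.path

params = GptParams(
    lora=["adapters/style.gguf"],               # applied at the default scale of 1.0
    lora_scaled=[("adapters/tone.gguf", 0.5)],  # applied at an explicit scale
)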
@@ -257,16 +258,56 @@ def gpt_params_parse(argv=None):
     parser.add_argument(
         "--lora",
         type=str,
-        default="",
-        help="apply LoRA adapter (implies --no-mmap)",
-        dest="lora_adapter",
-    )
-    parser.add_argument(
-        "--lora-base",
-        type=str,
-        default="",
-        help="optional model to use as a base for the layers modified by the LoRA adapter",
-        dest="lora_base",
+        action="append",
+        default=[],
+        help="path to LoRA adapter (can be repeated to use multiple adapters)",
+        metavar="FNAME",
+        dest="lora",
+    )
+
+    class MultiTupleAction(argparse.Action):
+        """Action for handling multiple arguments as tuples with type conversion"""
+        def __init__(self,
+                     option_strings: Sequence[str],
+                     dest: str,
+                     nargs: int = None,
+                     type: Tuple = None,
+                     metavar: Tuple = None,
+                     **kwargs):
+            self.tuple_type = type
+            super().__init__(
+                option_strings=option_strings,
+                dest=dest,
+                type=str,  # raw strings here; converted to tuple_type in __call__
+                nargs=nargs,
+                metavar=metavar,
+                **kwargs
+            )
+
+        def __call__(self, parser, namespace, values, option_string=None):
+            if len(values) != self.nargs:
+                parser.error(
+                    f'{option_string} requires {len(self.metavar)} arguments: '
+                    f'{" ".join(self.metavar)}'
+                )
+
+            converted_values = tuple(value_type(value) for value_type, value in zip(typing.get_args(self.tuple_type), values))
+            # Initialize list if needed
+            if not hasattr(namespace, self.dest):
+                setattr(namespace, self.dest, [])
+
+            # Add the converted tuple to the list
+            getattr(namespace, self.dest).append(converted_values)
+
+    parser.add_argument(
+        "--lora-scaled",
+        action=MultiTupleAction,
+        nargs=2,
+        type=Tuple[str, float],
+        help="path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
+        metavar=('FNAME', 'SCALE'),
+        dest='lora_scaled',
+        default=[],
     )
 
     parser.add_argument(
@@ -375,9 +416,6 @@ def gpt_params_parse(argv=None):
     delattr(args, "logit_bias_str")
     params = GptParams(**vars(args))
 
-    if params.lora_adapter:
-        params.use_mmap = False
-
     if logit_bias_str != None:
         for i in logit_bias_str:
             if m := re.match(r"(\d+)([-+]\d+)", i):
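MultiTupleAction is the interesting part of this file: argparse has no built-in action that collects a repeated option whose arguments have different types, so the action remembers the declared Tuple[...] type, lets argparse hand over raw strings, and converts each one via typing.get_args in __call__. A self-contained sketch of the same pattern, runnable outside the example (option names and inputs invented for illustration):

import argparse
import typing
from typing import Tuple


class MultiTupleAction(argparse.Action):
    """Collect each use of an N-argument option as a typed tuple."""

    def __init__(self, option_strings, dest, nargs=None, type=None, metavar=None, **kwargs):
        self.tuple_type = type  # e.g. Tuple[str, float]
        super().__init__(option_strings, dest, type=str, nargs=nargs,
                         metavar=metavar, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Convert each raw string with the matching element type of the tuple
        converted = tuple(t(v) for t, v in zip(typing.get_args(self.tuple_type), values))
        items = getattr(namespace, self.dest, None) or []
        items.append(converted)
        setattr(namespace, self.dest, items)


parser = argparse.ArgumentParser()
parser.add_argument("--lora-scaled", action=MultiTupleAction, nargs=2,
                    type=Tuple[str, float], metavar=("FNAME", "SCALE"),
                    dest="lora_scaled", default=[])

args = parser.parse_args(["--lora-scaled", "a.gguf", "0.5",
                          "--lora-scaled", "b.gguf", "1.5"])
print(args.lora_scaled)  # [('a.gguf', 0.5), ('b.gguf', 1.5)]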

examples/low_level_api/low_level_api_chat_cpp.py

Lines changed: 8 additions & 16 deletions
@@ -93,22 +93,14 @@ def __init__(self, params: GptParams) -> None:
         if self.params.ignore_eos:
             self.params.logit_bias[llama_cpp.llama_token_eos()] = -float("inf")
 
-        if len(self.params.lora_adapter) > 0:
-            if (
-                llama_cpp.llama_apply_lora_from_file(
-                    self.ctx,
-                    self.params.lora_adapter.encode("utf8"),
-                    (
-                        self.params.lora_base.encode("utf8")
-                        if len(self.params.lora_base) > 0
-                        else None
-                    ),
-                    self.params.n_threads,
-                )
-                != 0
-            ):
-                print("error: failed to apply lora adapter")
-                return
+        for lora_path, scale in [(pth, 1.0) for pth in self.params.lora] + self.params.lora_scaled:
+            lora_adapter = llama_cpp.llama_lora_adapter_init(
+                self.model,
+                lora_path.encode("utf8"))
+            if lora_adapter is None:
+                raise RuntimeError(f"error: failed to load lora adapter '{lora_path}'")
+            if scale != 0.0:
+                llama_cpp.llama_lora_adapter_set(self.ctx, lora_adapter, scale)
 
         print(file=sys.stderr)
         print(
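Note on the replacement: the removed code used the old whole-model llama_apply_lora_from_file call; the new loop drives the per-adapter API from the diff, initializing each adapter against the model with llama_lora_adapter_init and attaching it to the context with llama_lora_adapter_set. Plain --lora paths are folded in at scale 1.0 ahead of the --lora-scaled pairs; the merge itself is ordinary Python and can be sanity-checked without a model (paths invented):

lora = ["a.gguf", "b.gguf"]      # from repeated --lora, default scale
lora_scaled = [("c.gguf", 0.5)]  # from repeated --lora-scaled

adapters = [(pth, 1.0) for pth in lora] + lora_scaled
print(adapters)  # [('a.gguf', 1.0), ('b.gguf', 1.0), ('c.gguf', 0.5)]

A scale of 0.0 still initializes the adapter but skips llama_lora_adapter_set, so it is loaded for the model yet left inactive for this context.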
