
Commit 9d8e9a3

[Feature] add V1 Engine LoRA support (vllm-project#801)
What this PR does / why we need it?
According to RFC vllm-project#396 and vllm-project#448, this PR adds the relevant code to support LoRA in the V1 Engine.

Does this PR introduce any user-facing change?
The following OpenAI-compatible HTTP APIs will be supported:
/v1/load_lora_adapter
/v1/unload_lora_adapter

How was this patch tested?
git clone https://github.com/vllm-project/vllm.git
cd vllm/examples/offline_inference/ && python3 multilora_inference.py

Signed-off-by: jesse <szxfml@gmail.com>
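For reference, the two endpoints can be exercised with plain HTTP requests once a server is running with LoRA enabled. A minimal sketch, assuming a local OpenAI-compatible server on port 8000; the lora_name/lora_path payload fields and the adapter path are illustrative, not taken from this diff:

# Minimal sketch of calling the dynamic LoRA endpoints named in the PR description.
# Assumes a vLLM OpenAI-compatible server on localhost:8000 started with LoRA
# enabled; the adapter path below is a placeholder.
import requests

BASE_URL = "http://localhost:8000"

# Register an adapter under a name (payload fields are assumed).
resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                     json={
                         "lora_name": "sql_lora",
                         "lora_path": "/path/to/sql_lora_adapter",
                     })
print(resp.status_code, resp.text)

# Later, unload the adapter by the same name.
resp = requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
                     json={"lora_name": "sql_lora"})
print(resp.status_code, resp.text)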
1 parent cdece86 commit 9d8e9a3

8 files changed: +1272 −33 lines

tests/singlecard/test_baichuan.py

Lines changed: 110 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0

import pytest
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest

MODEL_PATH = "baichuan-inc/Baichuan-7B"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "Show name, country, age for all singers ordered by age from the oldest to the youngest."  # noqa: E501
        ),
    ]
    print(prompts)
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_baichuan_lora(baichuan_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'",  # noqa: E501
        "SELECT name , country , age FROM singer ORDER BY age ASC",
    ]

    output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i] == expected_lora_output[i]
    output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i] == expected_lora_output[i]


@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
                                           num_gpus_available, fully_sharded):
    if num_gpus_available < 4:
        pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

    llm_tp1 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)

    del llm_tp1
    cleanup_dist_env_and_memory()

    llm_tp2 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=2,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)

    del llm_tp2
    cleanup_dist_env_and_memory()

    assert output_tp1 == output_tp2

    llm_tp4 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=4,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)

    del llm_tp4
    cleanup_dist_env_and_memory()

    assert output_tp1 == output_tp4
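For context, the pattern this test file exercises is one shared base model with a per-request adapter: LoRARequest bundles an adapter name, an integer adapter id, and a local adapter path, and is passed alongside the prompts to llm.generate. A minimal offline sketch of the same pattern, assuming a placeholder adapter path (not part of this diff):

# Minimal offline sketch mirroring do_sample above; the adapter path is a placeholder.
import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM("baichuan-inc/Baichuan-7B",
               enable_lora=True,
               max_loras=4,
               max_lora_rank=64,
               trust_remote_code=True)

# LoRARequest(name, id, path): the integer id identifies the adapter
# within the engine; passing lora_request=None runs the base model.
req = LoRARequest("sql_adapter", 1, "/path/to/sql_lora_adapter")
outputs = llm.generate(["Write a SQL query counting singers."],
                       vllm.SamplingParams(temperature=0, max_tokens=64),
                       lora_request=req)
print(outputs[0].outputs[0].text)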
Lines changed: 133 additions & 0 deletions
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import pytest
4+
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
5+
VllmConfig)
6+
from vllm.lora.request import LoRARequest
7+
from vllm.sampling_params import SamplingParams
8+
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
9+
from vllm.v1.engine.processor import Processor
10+
11+
12+
def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
13+
sql_lora_files):
14+
"""
15+
Test that we properly resolve the range of allowed token ids for lora
16+
adapters that define additional tokens.
17+
"""
18+
19+
# Setup a base model compatible with the sql_lora_files adapter and
20+
# a known number of tokens in the base model.
21+
model_config = ModelConfig(
22+
model=llama_2_7b_base_huggingface_id,
23+
tokenizer=llama_2_7b_base_huggingface_id,
24+
tokenizer_mode="auto",
25+
)
26+
27+
vllm_config = VllmConfig(
28+
model_config=model_config,
29+
cache_config=CacheConfig(),
30+
device_config=DeviceConfig(),
31+
lora_config=LoRAConfig(),
32+
)
33+
34+
tokenizer = init_tokenizer_from_configs(
35+
model_config=vllm_config.model_config,
36+
scheduler_config=vllm_config.scheduler_config,
37+
lora_config=vllm_config.lora_config)
38+
processor = Processor(vllm_config, tokenizer)
39+
40+
lora_request = LoRARequest("1", 1, str(sql_lora_files))
41+
request_id = "1"
42+
prompt = "a prompt"
43+
44+
# tokens added in the lora adapter should not raise an error
45+
lora_token_ids = [32000, 32001, 32002, 32003]
46+
processor.process_inputs(
47+
request_id,
48+
prompt,
49+
params=SamplingParams(allowed_token_ids=lora_token_ids),
50+
lora_request=lora_request)
51+
52+
# tokens in the base model should not raise an error
53+
base_token_ids = [1000, 1001, 1002, 1003]
54+
processor.process_inputs(
55+
request_id,
56+
prompt,
57+
params=SamplingParams(allowed_token_ids=base_token_ids),
58+
lora_request=lora_request)
59+
60+
# tokens not in the lora adapter should raise an error
61+
invalid_token_ids = [35000, 35001, 35002, 35003]
62+
with pytest.raises(ValueError):
63+
processor.process_inputs(
64+
request_id,
65+
prompt,
66+
params=SamplingParams(allowed_token_ids=invalid_token_ids),
67+
lora_request=lora_request)
68+
69+
# tokens in the lora adapter with no lora request should raise an error
70+
with pytest.raises(ValueError):
71+
processor.process_inputs(
72+
request_id,
73+
prompt,
74+
params=SamplingParams(allowed_token_ids=lora_token_ids),
75+
)
76+
77+
78+
def test_allowed_token_ids_with_lora_adapter_no_vocab(
79+
qwen25vl_base_huggingface_id, qwen25vl_lora_files):
80+
"""
81+
Test that we properly resolve the range of allowed token ids for lora
82+
adapters that do not define additional tokens.
83+
"""
84+
85+
# Setup a base model compatible with the qwen25vl_lora_files adapter and
86+
# a known number of tokens in the base model.
87+
model_config = ModelConfig(
88+
model=qwen25vl_base_huggingface_id,
89+
tokenizer=qwen25vl_base_huggingface_id,
90+
tokenizer_mode="auto",
91+
)
92+
93+
vllm_config = VllmConfig(
94+
model_config=model_config,
95+
cache_config=CacheConfig(),
96+
device_config=DeviceConfig(),
97+
lora_config=LoRAConfig(),
98+
)
99+
100+
tokenizer = init_tokenizer_from_configs(
101+
model_config=vllm_config.model_config,
102+
scheduler_config=vllm_config.scheduler_config,
103+
lora_config=vllm_config.lora_config)
104+
processor = Processor(vllm_config, tokenizer)
105+
106+
lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
107+
request_id = "1"
108+
prompt = "a prompt"
109+
110+
# tokens in the base model should not raise an error
111+
base_token_ids = [1000, 1001, 1002, 1003]
112+
processor.process_inputs(
113+
request_id,
114+
prompt,
115+
params=SamplingParams(allowed_token_ids=base_token_ids),
116+
lora_request=lora_request)
117+
118+
# tokens in the base model with no lora request should not raise an error
119+
base_token_ids = [1000, 1001, 1002, 1003]
120+
processor.process_inputs(
121+
request_id,
122+
prompt,
123+
params=SamplingParams(allowed_token_ids=base_token_ids),
124+
)
125+
126+
# tokens not in the base model should raise an error
127+
invalid_token_ids = [200000, 200001, 200002, 200003]
128+
with pytest.raises(ValueError):
129+
processor.process_inputs(
130+
request_id,
131+
prompt,
132+
params=SamplingParams(allowed_token_ids=invalid_token_ids),
133+
lora_request=lora_request)
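The behaviour these tests pin down is that allowed_token_ids must fall inside the base vocabulary plus, when a LoRA request with extra tokens is attached, the adapter's added vocabulary. A simplified, standalone sketch of that check, where the function name and arguments are illustrative assumptions rather than vLLM's actual Processor internals:

# Illustrative-only sketch of the validation these tests exercise;
# the real check lives in vLLM's V1 Processor and differs in detail.
def validate_allowed_token_ids(allowed_token_ids: list[int],
                               base_vocab_size: int,
                               lora_extra_vocab_size: int = 0) -> None:
    # With a LoRA request, ids added by the adapter (appended after the
    # base vocabulary) are also acceptable; without one, only base ids are.
    upper_bound = base_vocab_size + lora_extra_vocab_size
    for token_id in allowed_token_ids:
        if not (0 <= token_id < upper_bound):
            raise ValueError(
                f"allowed_token_ids contains out-of-vocabulary id {token_id} "
                f"(effective vocab size: {upper_bound})")

# e.g. a base vocab of 32000 with 4 adapter-added tokens accepts
# ids 32000-32003 only when the LoRA request is attached.
validate_allowed_token_ids([32000, 32001], 32000, lora_extra_vocab_size=4)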
Lines changed: 128 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0

import pytest
from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.utils import WeightsMapper

lora_lst = [
    "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
]
BAICHUAN_LORA_MODULES = [
    "W_pack",
    "o_proj",
    "gate_up_proj",
    "down_proj",
]


@pytest.mark.parametrize("lora_name", lora_lst)
def test_load_checkpoints(
    lora_name,
    baichuan_lora_files,
    baichuan_zero_lora_files,
    baichuan_regex_lora_files,
    chatglm3_lora_files,
):
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
    expected_lora_modules: list[str] = []
    for module in BAICHUAN_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)
    if lora_name == "baichuan7B":
        peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
                                                max_position_embeddings=4096)
        # For the baichuan7B model, load its LoRA,
        # and the test should pass.
        LoRAModel.from_local_checkpoint(
            baichuan_lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)
    elif lora_name == "baichuan7B-zero":
        # Test that the target_modules contain a prefix
        # such as "model.layers.0.self_atten.W_pack", and
        # the test should pass.
        peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files,
                                                max_position_embeddings=4096)
        LoRAModel.from_local_checkpoint(
            baichuan_zero_lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)
    elif lora_name == "baichuan7B-zero-regex":
        # Test that `target_modules` may be given as regular expressions,
        # such as `model\\..*(W_pack|o_proj)`, and the test should pass.
        peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files,
                                                max_position_embeddings=4096)
        LoRAModel.from_local_checkpoint(
            baichuan_regex_lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)
    else:
        # For the baichuan7B model, load chatglm3-6b's LoRA,
        # and the test should raise the following error.
        expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
        peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files,
                                                max_position_embeddings=4096)
        with pytest.raises(ValueError, match=expected_error):
            LoRAModel.from_local_checkpoint(
                chatglm3_lora_files,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=1,
                device="cpu",
                embedding_modules=embedding_modules,
                embedding_padding_modules=embed_padding_modules)


def test_lora_weights_mapping(baichuan_lora_files):

    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
    expected_lora_modules: list[str] = []
    for module in BAICHUAN_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
        },
        orig_to_new_substr={
            ".layers.": ".baichuan_layers.",
        },
    )
    peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_local_checkpoint(
        baichuan_lora_files,
        expected_lora_modules,
        peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,
        embedding_padding_modules=embed_padding_modules,
        weights_mapper=hf_to_vllm_mapper,
    )
    for name in lora_model.loras:
        assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
        assert ".baichuan_layers." in name
