
Commit 2dc2c4e

[Feature] add V1 Engine LoRA support (vllm-project#801)
What this PR does / why we need it?
Following RFC vllm-project#396 and vllm-project#448, this PR adds the relevant code to support LoRA in the V1 Engine.

Does this PR introduce any user-facing change?
The following OpenAI-compatible HTTP APIs will be supported:
/v1/load_lora_adapter
/v1/unload_lora_adapter

How was this patch tested?
git clone https://github.com/vllm-project/vllm.git
cd vllm/examples/offline_inference/ && python3 multilora_inference.py

Signed-off-by: jesse <szxfml@gmail.com>
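The two adapter-management endpoints listed above can be exercised against a running OpenAI-compatible server. The sketch below is illustrative and not part of this commit: it assumes a server on localhost:8000 started with LoRA enabled and runtime adapter updating allowed, the adapter name and path are placeholders, and the payload fields follow the upstream vLLM dynamic-LoRA API.

# Minimal sketch (not part of this diff): load and unload a LoRA adapter
# over HTTP. Assumes a vLLM OpenAI-compatible server on localhost:8000 with
# LoRA and runtime adapter updating enabled; "sql_lora" and the adapter path
# are placeholders.
import requests

BASE_URL = "http://localhost:8000"

# Register an adapter so later requests can select it by name.
resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                     json={
                         "lora_name": "sql_lora",
                         "lora_path": "/path/to/sql_lora_adapter",
                     })
resp.raise_for_status()

# ... send /v1/completions requests that reference "sql_lora" ...

# Remove the adapter once it is no longer needed.
resp = requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
                     json={"lora_name": "sql_lora"})
resp.raise_for_status()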
1 parent cdece86 commit 2dc2c4e

8 files changed: +1278 −33 lines changed

tests/singlecard/test_baichuan.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest

MODEL_PATH = "baichuan-inc/Baichuan-7B"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "Show name, country, age for all singers ordered by age from the oldest to the youngest."  # noqa: E501
        ),
    ]
    print(prompts)
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_baichuan_lora(baichuan_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'",  # noqa: E501
        "SELECT name , country , age FROM singer ORDER BY age ASC",
    ]

    output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i] == expected_lora_output[i]
    output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i] == expected_lora_output[i]


@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
                                           num_gpus_available, fully_sharded):
    if num_gpus_available < 4:
        pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

    llm_tp1 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)

    del llm_tp1
    cleanup_dist_env_and_memory()

    llm_tp2 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=2,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)

    del llm_tp2
    cleanup_dist_env_and_memory()

    assert output_tp1 == output_tp2

    llm_tp4 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=4,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)

    del llm_tp4
    cleanup_dist_env_and_memory()

    assert output_tp1 == output_tp4
Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         VllmConfig)
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.v1.engine.processor import Processor


def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
                                           sql_lora_files):
    """
    Test that we properly resolve the range of allowed token ids for lora
    adapters that define additional tokens.
    """

    # Setup a base model compatible with the sql_lora_files adapter and
    # a known number of tokens in the base model.
    model_config = ModelConfig(
        model=llama_2_7b_base_huggingface_id,
        tokenizer=llama_2_7b_base_huggingface_id,
        tokenizer_mode="auto",
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        device_config=DeviceConfig(),
        lora_config=LoRAConfig(),
    )

    tokenizer = init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config)
    processor = Processor(vllm_config, tokenizer)

    lora_request = LoRARequest("1", 1, str(sql_lora_files))
    request_id = "1"
    prompt = "a prompt"

    # tokens added in the lora adapter should not raise an error
    lora_token_ids = [32000, 32001, 32002, 32003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=lora_token_ids),
        lora_request=lora_request)

    # tokens in the base model should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
        lora_request=lora_request)

    # tokens not in the lora adapter should raise an error
    invalid_token_ids = [35000, 35001, 35002, 35003]
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=invalid_token_ids),
            lora_request=lora_request)

    # tokens in the lora adapter with no lora request should raise an error
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=lora_token_ids),
        )


def test_allowed_token_ids_with_lora_adapter_no_vocab(
        qwen25vl_base_huggingface_id, qwen25vl_lora_files):
    """
    Test that we properly resolve the range of allowed token ids for lora
    adapters that do not define additional tokens.
    """

    # Setup a base model compatible with the qwen25vl_lora_files adapter and
    # a known number of tokens in the base model.
    model_config = ModelConfig(
        model=qwen25vl_base_huggingface_id,
        tokenizer=qwen25vl_base_huggingface_id,
        tokenizer_mode="auto",
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        device_config=DeviceConfig(),
        lora_config=LoRAConfig(),
    )

    tokenizer = init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config)
    processor = Processor(vllm_config, tokenizer)

    lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
    request_id = "1"
    prompt = "a prompt"

    # tokens in the base model should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
        lora_request=lora_request)

    # tokens in the base model with no lora request should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
    )

    # tokens not in the base model should raise an error
    invalid_token_ids = [200000, 200001, 200002, 200003]
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=invalid_token_ids),
            lora_request=lora_request)
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.utils import WeightsMapper

lora_lst = [
    "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
]
BAICHUAN_LORA_MODULES = [
    "W_pack",
    "o_proj",
    "gate_up_proj",
    "down_proj",
]


@pytest.mark.parametrize("lora_name", lora_lst)
def test_load_checkpoints(
    lora_name,
    baichuan_lora_files,
    baichuan_zero_lora_files,
    baichuan_regex_lora_files,
    chatglm3_lora_files,
):
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
    expected_lora_modules: list[str] = []
    for module in BAICHUAN_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)
    if lora_name == "baichuan7B":
        peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
                                                max_position_embeddings=4096)
        # For the baichuan7B model, load its LoRA,
        # and the test should pass.
        LoRAModel.from_local_checkpoint(
            baichuan_lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)
    elif lora_name == "baichuan7B-zero":
        # Test that the target_modules contain a full prefix
        # such as "model.layers.0.self_atten.W_pack", and
        # the test should pass.
        peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files,
                                                max_position_embeddings=4096)
        LoRAModel.from_local_checkpoint(
            baichuan_zero_lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)
    elif lora_name == "baichuan7B-zero-regex":
        # Test that the `target_modules` can be given as regular expressions,
        # such as `model\\..*(W_pack|o_proj)`, and the test should pass.
        peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files,
                                                max_position_embeddings=4096)
        LoRAModel.from_local_checkpoint(
            baichuan_regex_lora_files,
            expected_lora_modules,
            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)
    else:
        # For the baichuan7B model, load chatglm3-6b's LoRA,
        # and the test should raise the following error.
        expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
        peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files,
                                                max_position_embeddings=4096)
        with pytest.raises(ValueError, match=expected_error):
            LoRAModel.from_local_checkpoint(
                chatglm3_lora_files,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=1,
                device="cpu",
                embedding_modules=embedding_modules,
                embedding_padding_modules=embed_padding_modules)


def test_lora_weights_mapping(baichuan_lora_files):

    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
    expected_lora_modules: list[str] = []
    for module in BAICHUAN_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
        },
        orig_to_new_substr={
            ".layers.": ".baichuan_layers.",
        },
    )
    peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_local_checkpoint(
        baichuan_lora_files,
        expected_lora_modules,
        peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,
        embedding_padding_modules=embed_padding_modules,
        weights_mapper=hf_to_vllm_mapper,
    )
    for name in lora_model.loras:
        assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
        assert ".baichuan_layers." in name
