
Commit 3b99874

[Feature] add V1 Engine LoRA support (vllm-project#801)
What this PR does / why we need it?
Following RFC vllm-project#396 and vllm-project#448, this PR adds the relevant code to support LoRA in the V1 Engine.

Does this PR introduce any user-facing change?
The following OpenAI-compatible HTTP APIs will be supported:
/v1/load_lora_adapter
/v1/unload_lora_adapter

How was this patch tested?
git clone https://github.com/vllm-project/vllm.git
cd vllm/examples/offline_inference/ && python3 multilora_inference.py

Signed-off-by: jesse <szxfml@gmail.com>
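For reference, below is a minimal client-side sketch of the two new endpoints. It assumes a server started with LoRA enabled (for example, vllm serve <model> --enable-lora) and, depending on the vLLM version, runtime adapter loading allowed via VLLM_ALLOW_RUNTIME_LORA_UPDATING=True; the adapter name and path shown are hypothetical, and the JSON fields follow vLLM's dynamic LoRA loading API.

# Sketch only: load and unload a LoRA adapter at runtime over HTTP.
import requests

BASE_URL = "http://localhost:8000"  # assumed server address

# Load an adapter at runtime.
resp = requests.post(
    f"{BASE_URL}/v1/load_lora_adapter",
    json={
        "lora_name": "sql_lora",           # hypothetical adapter name
        "lora_path": "/path/to/sql_lora",  # hypothetical local adapter path
    })
print(resp.status_code, resp.text)

# Unload it again.
resp = requests.post(
    f"{BASE_URL}/v1/unload_lora_adapter",
    json={"lora_name": "sql_lora"})
print(resp.status_code, resp.text)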
1 parent cdece86 commit 3b99874

9 files changed: +1279 −38 lines

.github/workflows/vllm_ascend_test.yaml

Lines changed: 7 additions & 5 deletions
@@ -51,11 +51,11 @@ jobs:
         vllm_verison: [main, v0.8.5.post1]
     concurrency:
       group: >
-        ${{
-        matrix.os == 'linux-arm64-npu-4'
-        && github.event.pull_request.number
-        && format('pr-{0}-limit-npu-4', github.event.pull_request.number)
-        || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
+        ${{
+        matrix.os == 'linux-arm64-npu-4'
+        && github.event.pull_request.number
+        && format('pr-{0}-limit-npu-4', github.event.pull_request.number)
+        || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
         }}
     cancel-in-progress: false
     name: vLLM Ascend test
@@ -112,6 +112,7 @@ jobs:
       run: |
         if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
           pytest -sv tests/singlecard/test_offline_inference.py
+          pytest -sv tests/singlecard/test_lora_functions.py
           pytest -sv tests/ops
           pytest -sv tests/compile
         else
@@ -126,6 +127,7 @@ jobs:
       run: |
         if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
           pytest -sv tests/singlecard/test_offline_inference.py
+          pytest -sv tests/singlecard/test_lora_functions.py
           pytest -sv tests/ops
         else
           pytest -sv tests/multicard/test_offline_inference_distributed.py

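The workflow change above adds tests/singlecard/test_lora_functions.py to the single-card CI runs. To reproduce that step locally, the same module can be invoked through pytest's Python entry point; this is a sketch that assumes it is run from a vllm-ascend checkout with the test dependencies installed.

# Local equivalent of the new CI step; mirrors
# `pytest -sv tests/singlecard/test_lora_functions.py` from the workflow.
import sys

import pytest

sys.exit(pytest.main(["-sv", "tests/singlecard/test_lora_functions.py"]))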
tests/singlecard/test_baichuan.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest

MODEL_PATH = "baichuan-inc/Baichuan-7B"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "Show name, country, age for all singers ordered by age from the oldest to the youngest."  # noqa: E501
        ),
    ]
    print(prompts)
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_baichuan_lora(baichuan_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'",  # noqa: E501
        "SELECT name , country , age FROM singer ORDER BY age ASC",
    ]

    output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i] == expected_lora_output[i]
    output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i] == expected_lora_output[i]


@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
                                           num_gpus_available, fully_sharded):
    if num_gpus_available < 4:
        pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

    llm_tp1 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)

    del llm_tp1
    cleanup_dist_env_and_memory()

    llm_tp2 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=2,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)

    del llm_tp2
    cleanup_dist_env_and_memory()

    assert output_tp1 == output_tp2

    llm_tp4 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=4,
                       trust_remote_code=True,
                       fully_sharded_loras=fully_sharded)
    output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)

    del llm_tp4
    cleanup_dist_env_and_memory()

    assert output_tp1 == output_tp4
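The baichuan_lora_files and num_gpus_available fixtures used above are not part of this diff and are expected to come from the test suite's conftest.py. Below is a minimal sketch of what such a conftest could provide, assuming the adapter is fetched from the Hugging Face Hub; the repository id and the device-count query are hypothetical placeholders.

# Hypothetical conftest sketch for the fixtures used above; the real
# fixtures live in the test suite's conftest.py and may differ.
import pytest
import torch
from huggingface_hub import snapshot_download


@pytest.fixture(scope="session")
def baichuan_lora_files():
    # Placeholder repository id for a Baichuan-7B text2sql LoRA adapter.
    return snapshot_download(repo_id="your-org/baichuan7b-text2sql-lora")


@pytest.fixture
def num_gpus_available():
    # On Ascend hardware the real fixture would query the NPU device count.
    return torch.cuda.device_count()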
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         VllmConfig)
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.v1.engine.processor import Processor


def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
                                           sql_lora_files):
    """
    Test that we properly resolve the range of allowed token ids for lora
    adapters that define additional tokens.
    """

    # Setup a base model compatible with the sql_lora_files adapter and
    # a known number of tokens in the base model.
    model_config = ModelConfig(
        model=llama_2_7b_base_huggingface_id,
        tokenizer=llama_2_7b_base_huggingface_id,
        tokenizer_mode="auto",
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        device_config=DeviceConfig(),
        lora_config=LoRAConfig(),
    )

    tokenizer = init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config)
    processor = Processor(vllm_config, tokenizer)

    lora_request = LoRARequest("1", 1, str(sql_lora_files))
    request_id = "1"
    prompt = "a prompt"

    # tokens added in the lora adapter should not raise an error
    lora_token_ids = [32000, 32001, 32002, 32003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=lora_token_ids),
        lora_request=lora_request)

    # tokens in the base model should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
        lora_request=lora_request)

    # tokens not in the lora adapter should raise an error
    invalid_token_ids = [35000, 35001, 35002, 35003]
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=invalid_token_ids),
            lora_request=lora_request)

    # tokens in the lora adapter with no lora request should raise an error
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=lora_token_ids),
        )


def test_allowed_token_ids_with_lora_adapter_no_vocab(
        qwen25vl_base_huggingface_id, qwen25vl_lora_files):
    """
    Test that we properly resolve the range of allowed token ids for lora
    adapters that do not define additional tokens.
    """

    # Setup a base model compatible with the qwen25vl_lora_files adapter and
    # a known number of tokens in the base model.
    model_config = ModelConfig(
        model=qwen25vl_base_huggingface_id,
        tokenizer=qwen25vl_base_huggingface_id,
        tokenizer_mode="auto",
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        device_config=DeviceConfig(),
        lora_config=LoRAConfig(),
    )

    tokenizer = init_tokenizer_from_configs(
        model_config=vllm_config.model_config,
        scheduler_config=vllm_config.scheduler_config,
        lora_config=vllm_config.lora_config)
    processor = Processor(vllm_config, tokenizer)

    lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
    request_id = "1"
    prompt = "a prompt"

    # tokens in the base model should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
        lora_request=lora_request)

    # tokens in the base model with no lora request should not raise an error
    base_token_ids = [1000, 1001, 1002, 1003]
    processor.process_inputs(
        request_id,
        prompt,
        params=SamplingParams(allowed_token_ids=base_token_ids),
    )

    # tokens not in the base model should raise an error
    invalid_token_ids = [200000, 200001, 200002, 200003]
    with pytest.raises(ValueError):
        processor.process_inputs(
            request_id,
            prompt,
            params=SamplingParams(allowed_token_ids=invalid_token_ids),
            lora_request=lora_request)
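The token-id ranges in these tests follow from the vocabulary sizes involved: Llama-2 uses a 32000-token base vocabulary, so ids 32000-32003 are only valid when the sql_lora adapter's extra vocabulary is attached, while ids 35000 and above fall outside both. Below is a minimal sketch of the validation rule being exercised; it is illustrative only, not vLLM's actual implementation.

# Sketch of the allowed_token_ids validation rule exercised above
# (illustrative only; vLLM's Processor performs the real check).
def validate_allowed_token_ids(allowed_token_ids: list[int],
                               base_vocab_size: int,
                               lora_extra_vocab_size: int = 0,
                               has_lora_request: bool = False) -> None:
    # The valid range is the base vocabulary, extended by the adapter's
    # extra vocabulary only when a LoRA request accompanies the prompt.
    upper = base_vocab_size + (lora_extra_vocab_size if has_lora_request else 0)
    for token_id in allowed_token_ids:
        if not 0 <= token_id < upper:
            raise ValueError(
                f"allowed_token_ids contains out-of-vocabulary id {token_id}")


# Ids 32000-32003 pass only because the LoRA request brings extra vocabulary
# (the extra vocab size of 256 here is an assumed example value).
validate_allowed_token_ids([32000, 32001, 32002, 32003],
                           base_vocab_size=32000,
                           lora_extra_vocab_size=256,
                           has_lora_request=True)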
