Skip to content

Commit adef148

Browse files
committed
[Feature] add V1 Engine LoRA support (vllm-project#801)
What this PR does / why we need it? According to this RFC vllm-project#396 and this vllm-project#448, we submit the relevant code to support LoRA in the v1 Engine. Does this PR introduce any user-facing change? The following OpenAI HTTP APIs will be supported: /v1/load_lora_adapter /v1/unload_lora_adapter How was this patch tested? git clone https://github.com/vllm-project/vllm.git cd vllm/examples/offline_inference/ && python3 multilora_inference.py Signed-off-by: jesse <szxfml@gmail.com>
1 parent 217211d commit adef148

File tree

4 files changed

+198
-38
lines changed

4 files changed

+198
-38
lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ jobs:
5151
vllm_verison: [main, v0.8.5.post1]
5252
concurrency:
5353
group: >
54-
${{
55-
matrix.os == 'linux-arm64-npu-4'
56-
&& github.event.pull_request.number
57-
&& format('pr-{0}-limit-npu-4', github.event.pull_request.number)
58-
|| format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
54+
${{
55+
matrix.os == 'linux-arm64-npu-4'
56+
&& github.event.pull_request.number
57+
&& format('pr-{0}-limit-npu-4', github.event.pull_request.number)
58+
|| format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
5959
}}
6060
cancel-in-progress: false
6161
name: vLLM Ascend test
@@ -112,6 +112,7 @@ jobs:
112112
run: |
113113
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
114114
pytest -sv tests/singlecard/test_offline_inference.py
115+
pytest -sv tests/singlecard/test_baichuan.py
115116
pytest -sv tests/ops
116117
pytest -sv tests/compile
117118
else
@@ -126,6 +127,7 @@ jobs:
126127
run: |
127128
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
128129
pytest -sv tests/singlecard/test_offline_inference.py
130+
pytest -sv tests/singlecard/test_baichuan.py
129131
pytest -sv tests/ops
130132
else
131133
pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py

tests/singlecard/test_baichuan.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import pytest
4+
import vllm
5+
from vllm.distributed import cleanup_dist_env_and_memory
6+
from vllm.lora.request import LoRARequest
7+
from huggingface_hub import snapshot_download
8+
9+
MODEL_PATH = "baichuan-inc/Baichuan-7B"
10+
11+
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
12+
13+
14+
@pytest.fixture(scope="session")
def baichuan_lora_files():
    """Download the Baichuan-7B text2sql LoRA adapter once per test session.

    Returns the local path of the snapshot so tests can hand it to
    ``LoRARequest`` without re-downloading per test.
    """
    lora_repo = "jeeejeee/baichuan7b-text2sql-spider"
    return snapshot_download(repo_id=lora_repo)
17+
18+
19+
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    """Run greedy generation for three fixed text2sql prompts.

    When ``lora_id`` is truthy, the adapter at ``lora_path`` is attached via a
    ``LoRARequest``; otherwise the base model is used. Returns the stripped
    generated text for each prompt, in prompt order.
    """
    queries = [
        "How many singers do we have?",
        "What is the average, minimum, and maximum age of all singers from France?",  # noqa: E501
        "Show name, country, age for all singers ordered by age from the oldest to the youngest.",  # noqa: E501
    ]
    prompts = [PROMPT_TEMPLATE.format(query=q) for q in queries]
    print(prompts)
    # temperature=0 keeps the outputs deterministic so callers can compare
    # results across runs / parallelism configurations.
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
    lora_request = (LoRARequest(str(lora_id), lora_id, lora_path)
                    if lora_id else None)
    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)

    generated_texts: list[str] = []
    for output in outputs:
        text = output.outputs[0].text.strip()
        generated_texts.append(text)
        print(f"Prompt: {output.prompt!r}, Generated text: {text!r}")
    return generated_texts
46+
47+
48+
def test_baichuan_lora(baichuan_lora_files):
    """Check that the Baichuan-7B text2sql LoRA adapter yields the expected
    SQL completions, and that the same adapter registered under two distinct
    LoRA ids behaves identically."""
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'",  # noqa: E501
        "SELECT name , country , age FROM singer ORDER BY age ASC",
    ]

    # Greedy decoding (temperature=0 inside do_sample) makes the comparison
    # deterministic; adapter ids 1 and 2 point at the same weights.
    for lora_id in (1, 2):
        generated = do_sample(llm, baichuan_lora_files, lora_id=lora_id)
        for i, expected in enumerate(expected_lora_output):
            assert generated[i] == expected
68+
69+
70+
@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of accelerator cards the test run may use.

    NOTE(review): hardcoded to 1 — presumably to force the tensor-parallel
    test below to skip on single-card CI runners; confirm before raising it.
    """
    return 1
73+
74+
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
                                           num_gpus_available, fully_sharded):
    """LoRA outputs must be identical across tensor-parallel sizes 1, 2, 4.

    Each engine is torn down (and the distributed env cleaned up) before the
    next one is built, since only one engine can own the devices at a time.
    """
    if num_gpus_available < 4:
        pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

    def _sample_with_tp(tp_size: int, lora_id: int) -> list[str]:
        # Build an engine at the requested TP size, sample, then release
        # the devices so the next configuration can start cleanly.
        engine = vllm.LLM(MODEL_PATH,
                          enable_lora=True,
                          max_num_seqs=16,
                          max_loras=4,
                          max_lora_rank=64,
                          tensor_parallel_size=tp_size,
                          trust_remote_code=True,
                          fully_sharded_loras=fully_sharded)
        try:
            return do_sample(engine, baichuan_lora_files, lora_id=lora_id)
        finally:
            del engine
            cleanup_dist_env_and_memory()

    output_tp1 = _sample_with_tp(1, lora_id=1)
    output_tp2 = _sample_with_tp(2, lora_id=2)
    assert output_tp1 == output_tp2

    output_tp4 = _sample_with_tp(4, lora_id=2)
    assert output_tp1 == output_tp4

vllm_ascend/worker/model_runner_v1.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from vllm.v1.sample.sampler import Sampler
5151
from vllm.v1.utils import bind_kv_cache
5252
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
53+
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
5354

5455
from vllm_ascend.attention.attention import AttentionMaskBuilder
5556
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -102,7 +103,7 @@ def graph_capture(device: torch.device):
102103
yield graph_capture_context
103104

104105

105-
class NPUModelRunner:
106+
class NPUModelRunner(LoRAModelRunnerMixin):
106107

107108
def __init__(self, vllm_config: VllmConfig, device: torch.device):
108109
self.vllm_config = vllm_config
@@ -534,6 +535,10 @@ def _process_reqs(
534535
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
535536
num_tokens)
536537

538+
# Hot-Swap lora model
539+
if self.lora_config:
540+
self.set_active_loras(self.input_batch, num_scheduled_tokens)
541+
537542
# Prepare positions
538543
req_indices = np.repeat(self.arange_np[:num_reqs],
539544
num_scheduled_tokens)
@@ -857,39 +862,55 @@ def _profile_multimodal(self) -> None:
857862

858863
@torch.inference_mode()
859864
def _dummy_run(self, num_tokens: int) -> torch.Tensor:
860-
model = self.model
861-
if self.is_multimodal_model:
862-
input_ids = None
863-
inputs_embeds = self.inputs_embeds[:num_tokens]
864-
else:
865-
input_ids = self.input_ids[:num_tokens]
866-
inputs_embeds = None
865+
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
866+
# for dummy run with LoRA so that the num_reqs collectively
867+
# has num_tokens in total.
868+
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
869+
max_num_reqs = self.scheduler_config.max_num_seqs
870+
num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
871+
min_tokens_per_req = num_tokens // num_reqs
872+
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
873+
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
874+
assert sum(num_scheduled_tokens_list) == num_tokens
875+
assert len(num_scheduled_tokens_list) == num_reqs
876+
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
877+
dtype=np.int32)
878+
with self.maybe_dummy_run_with_lora(self.lora_config,
879+
num_scheduled_tokens):
880+
model = self.model
881+
if self.is_multimodal_model:
882+
input_ids = None
883+
inputs_embeds = self.inputs_embeds[:num_tokens]
884+
else:
885+
input_ids = self.input_ids[:num_tokens]
886+
inputs_embeds = None
867887

868-
if self.uses_mrope:
869-
positions = self.mrope_positions[:, :num_tokens]
870-
else:
871-
positions = self.positions[:num_tokens]
888+
if self.uses_mrope:
889+
positions = self.mrope_positions[:, :num_tokens]
890+
else:
891+
positions = self.positions[:num_tokens]
872892

873-
if get_pp_group().is_first_rank:
874-
intermediate_tensors = None
875-
else:
876-
if self.intermediate_tensors is None:
877-
self.intermediate_tensors = (
878-
self.model.make_empty_intermediate_tensors(
879-
batch_size=num_tokens,
880-
dtype=self.dtype,
881-
device=self.device))
882-
intermediate_tensors = IntermediateTensors({
883-
k: v[:num_tokens]
884-
for k, v in self.intermediate_tensors.items()
885-
})
886-
887-
with set_forward_context(None, self.vllm_config):
888-
hidden_states = model(input_ids=input_ids,
889-
positions=positions,
890-
intermediate_tensors=intermediate_tensors,
891-
inputs_embeds=inputs_embeds)
892-
return hidden_states
893+
if get_pp_group().is_first_rank:
894+
intermediate_tensors = None
895+
else:
896+
if self.intermediate_tensors is None:
897+
self.intermediate_tensors = (
898+
self.model.make_empty_intermediate_tensors(
899+
batch_size=self.max_num_tokens,
900+
dtype=self.model_config.dtype,
901+
device=self.device))
902+
intermediate_tensors = IntermediateTensors({
903+
k: v[:num_tokens]
904+
for k, v in self.intermediate_tensors.items()
905+
})
906+
907+
with set_forward_context(None, self.vllm_config):
908+
hidden_states = model(
909+
input_ids=input_ids,
910+
positions=positions.to(self.device),
911+
intermediate_tensors=intermediate_tensors,
912+
inputs_embeds=inputs_embeds)
913+
return hidden_states
893914

894915
def profile_run(self) -> None:
895916
# Profile with multimodal encoder & encoder cache.
@@ -938,7 +959,11 @@ def load_model(self) -> None:
938959
with DeviceMemoryProfiler() as m: # noqa: SIM117
939960
self.model = get_model(vllm_config=self.vllm_config)
940961
if self.lora_config:
941-
raise ValueError("LoRA model is not supported on NPU now.")
962+
self.model = self.load_lora_model(self.model,
963+
self.model_config,
964+
self.scheduler_config,
965+
self.lora_config,
966+
self.device)
942967
logger.info("Loading model weights took %.4f GB",
943968
m.consumed_memory / float(2**30))
944969

vllm_ascend/worker/worker_v1.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
set_custom_all_reduce)
3232
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
3333
from vllm.logger import logger
34+
from vllm.lora.request import LoRARequest
3435
from vllm.model_executor import set_random_seed
3536
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
3637
from vllm.v1.core.sched.output import SchedulerOutput
@@ -216,6 +217,18 @@ def profile(self, is_start: bool = True):
216217
else:
217218
self.profiler.stop()
218219

220+
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Register a LoRA adapter; delegates to the model runner."""
    return self.model_runner.add_lora(lora_request)
222+
223+
def remove_lora(self, lora_id: int) -> bool:
    """Remove the adapter with ``lora_id``; delegates to the model runner."""
    return self.model_runner.remove_lora(lora_id)
225+
226+
def list_loras(self) -> set[int]:
    """Return the ids of currently registered adapters (via the runner)."""
    return self.model_runner.list_loras()
228+
229+
def pin_lora(self, lora_id: int) -> bool:
    """Pin adapter ``lora_id`` so it stays resident; delegates to the runner."""
    return self.model_runner.pin_lora(lora_id)
231+
219232
def execute_dummy_batch(self) -> None:
220233
self.model_runner._dummy_run(1)
221234

0 commit comments

Comments
 (0)