Skip to content

Commit d78b819

Browse files
committed
[Feature] add V1 Engine LoRA support (vllm-project#801)
What this PR does / why we need it? According to this RFC vllm-project#396 and this vllm-project#448, we pull request relavant code to support LoRA in v1 Engine Does this PR introduce any user-facing change? Following openai HTTP apis will be supported: /v1/load_lora_adapter /v1/unload_lora_adapter How was this patch tested? git clone https://github.com/vllm-project/vllm.git cd vllm/examples/offline_inference/ && python3 multilora_inference.py Signed-off-by: jesse <szxfml@gmail.com>
1 parent 1e67089 commit d78b819

File tree

4 files changed

+146
-40
lines changed

4 files changed

+146
-40
lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ jobs:
5151
vllm_verison: [main, v0.8.5.post1]
5252
concurrency:
5353
group: >
54-
${{
55-
matrix.os == 'linux-arm64-npu-4'
56-
&& github.event.pull_request.number
57-
&& format('pr-{0}-limit-npu-4', github.event.pull_request.number)
58-
|| format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
54+
${{
55+
matrix.os == 'linux-arm64-npu-4'
56+
&& github.event.pull_request.number
57+
&& format('pr-{0}-limit-npu-4', github.event.pull_request.number)
58+
|| format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
5959
}}
6060
cancel-in-progress: false
6161
name: vLLM Ascend test
@@ -112,6 +112,7 @@ jobs:
112112
run: |
113113
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
114114
pytest -sv tests/singlecard/test_offline_inference.py
115+
pytest -sv tests/singlecard/test_transfomers_model.py
115116
pytest -sv tests/ops
116117
pytest -sv tests/compile
117118
else
@@ -126,6 +127,7 @@ jobs:
126127
run: |
127128
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
128129
pytest -sv tests/singlecard/test_offline_inference.py
130+
pytest -sv tests/singlecard/test_transfomers_model.py
129131
pytest -sv tests/ops
130132
else
131133
pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import pytest
4+
import vllm
5+
from huggingface_hub import snapshot_download
6+
from vllm.lora.request import LoRARequest
7+
8+
MODEL_PATH = "ArthurZ/ilama-3.2-1B"
9+
10+
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
11+
12+
EXPECTED_LORA_OUTPUT = [
13+
"SELECT count(*) FROM singer",
14+
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
15+
"SELECT DISTINCT Country FROM singer WHERE Age > 20",
16+
]
17+
18+
19+
@pytest.fixture(scope="session")
20+
def ilama_lora_files():
21+
return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
22+
23+
24+
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
25+
prompts = [
26+
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
27+
PROMPT_TEMPLATE.format(
28+
query=
29+
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
30+
),
31+
PROMPT_TEMPLATE.format(
32+
query=
33+
"What are all distinct countries where singers above age 20 are from?" # noqa: E501
34+
),
35+
]
36+
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
37+
outputs = llm.generate(
38+
prompts,
39+
sampling_params,
40+
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
41+
if lora_id else None)
42+
# Print the outputs.
43+
generated_texts: list[str] = []
44+
for output in outputs:
45+
prompt = output.prompt
46+
generated_text = output.outputs[0].text.strip()
47+
generated_texts.append(generated_text)
48+
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
49+
return generated_texts
50+
51+
52+
def test_ilama_lora(ilama_lora_files):
53+
llm = vllm.LLM(MODEL_PATH,
54+
max_model_len=1024,
55+
enable_lora=True,
56+
max_loras=4,
57+
max_lora_rank=16,
58+
trust_remote_code=True,
59+
enable_chunked_prefill=True)
60+
61+
output1 = do_sample(llm, ilama_lora_files, lora_id=1)
62+
for i in range(len(EXPECTED_LORA_OUTPUT)):
63+
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
64+
output2 = do_sample(llm, ilama_lora_files, lora_id=2)
65+
for i in range(len(EXPECTED_LORA_OUTPUT)):
66+
assert output2[i] == EXPECTED_LORA_OUTPUT[i]

vllm_ascend/worker/model_runner_v1.py

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from vllm.v1.sample.sampler import Sampler
5151
from vllm.v1.utils import bind_kv_cache
5252
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
53+
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
5354

5455
from vllm_ascend.attention.attention import AttentionMaskBuilder
5556
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -102,7 +103,7 @@ def graph_capture(device: torch.device):
102103
yield graph_capture_context
103104

104105

105-
class NPUModelRunner:
106+
class NPUModelRunner(LoRAModelRunnerMixin):
106107

107108
def __init__(self, vllm_config: VllmConfig, device: torch.device):
108109
self.vllm_config = vllm_config
@@ -507,8 +508,8 @@ def _process_reqs(
507508
assert total_num_scheduled_tokens > 0
508509
num_reqs = self.input_batch.num_reqs
509510
assert num_reqs > 0
510-
if (self.use_aclgraph and
511-
total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]):
511+
if (self.use_aclgraph and total_num_scheduled_tokens
512+
<= self.aclgraph_batch_sizes[-1]):
512513
# Add padding to the batch size.
513514
num_input_tokens = self.vllm_config.pad_for_cudagraph(
514515
total_num_scheduled_tokens)
@@ -534,6 +535,10 @@ def _process_reqs(
534535
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
535536
num_tokens)
536537

538+
# Hot-Swap lora model
539+
if self.lora_config:
540+
self.set_active_loras(self.input_batch, num_scheduled_tokens)
541+
537542
# Prepare positions
538543
req_indices = np.repeat(self.arange_np[:num_reqs],
539544
num_scheduled_tokens)
@@ -857,39 +862,55 @@ def _profile_multimodal(self) -> None:
857862

858863
@torch.inference_mode()
859864
def _dummy_run(self, num_tokens: int) -> torch.Tensor:
860-
model = self.model
861-
if self.is_multimodal_model:
862-
input_ids = None
863-
inputs_embeds = self.inputs_embeds[:num_tokens]
864-
else:
865-
input_ids = self.input_ids[:num_tokens]
866-
inputs_embeds = None
865+
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
866+
# for dummy run with LoRA so that the num_reqs collectively
867+
# has num_tokens in total.
868+
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
869+
max_num_reqs = self.scheduler_config.max_num_seqs
870+
num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
871+
min_tokens_per_req = num_tokens // num_reqs
872+
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
873+
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
874+
assert sum(num_scheduled_tokens_list) == num_tokens
875+
assert len(num_scheduled_tokens_list) == num_reqs
876+
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
877+
dtype=np.int32)
878+
with self.maybe_dummy_run_with_lora(self.lora_config,
879+
num_scheduled_tokens):
880+
model = self.model
881+
if self.is_multimodal_model:
882+
input_ids = None
883+
inputs_embeds = self.inputs_embeds[:num_tokens]
884+
else:
885+
input_ids = self.input_ids[:num_tokens]
886+
inputs_embeds = None
867887

868-
if self.uses_mrope:
869-
positions = self.mrope_positions[:, :num_tokens]
870-
else:
871-
positions = self.positions[:num_tokens]
888+
if self.uses_mrope:
889+
positions = self.mrope_positions[:, :num_tokens]
890+
else:
891+
positions = self.positions[:num_tokens]
872892

873-
if get_pp_group().is_first_rank:
874-
intermediate_tensors = None
875-
else:
876-
if self.intermediate_tensors is None:
877-
self.intermediate_tensors = (
878-
self.model.make_empty_intermediate_tensors(
879-
batch_size=num_tokens,
880-
dtype=self.dtype,
881-
device=self.device))
882-
intermediate_tensors = IntermediateTensors({
883-
k: v[:num_tokens]
884-
for k, v in self.intermediate_tensors.items()
885-
})
886-
887-
with set_forward_context(None, self.vllm_config):
888-
hidden_states = model(input_ids=input_ids,
889-
positions=positions,
890-
intermediate_tensors=intermediate_tensors,
891-
inputs_embeds=inputs_embeds)
892-
return hidden_states
893+
if get_pp_group().is_first_rank:
894+
intermediate_tensors = None
895+
else:
896+
if self.intermediate_tensors is None:
897+
self.intermediate_tensors = (
898+
self.model.make_empty_intermediate_tensors(
899+
batch_size=num_tokens,
900+
dtype=self.dtype,
901+
device=self.device))
902+
intermediate_tensors = IntermediateTensors({
903+
k: v[:num_tokens]
904+
for k, v in self.intermediate_tensors.items()
905+
})
906+
907+
with set_forward_context(None, self.vllm_config):
908+
hidden_states = model(
909+
input_ids=input_ids,
910+
positions=positions,
911+
intermediate_tensors=intermediate_tensors,
912+
inputs_embeds=inputs_embeds)
913+
return hidden_states
893914

894915
def profile_run(self) -> None:
895916
# Profile with multimodal encoder & encoder cache.
@@ -938,7 +959,11 @@ def load_model(self) -> None:
938959
with DeviceMemoryProfiler() as m: # noqa: SIM117
939960
self.model = get_model(vllm_config=self.vllm_config)
940961
if self.lora_config:
941-
raise ValueError("LoRA model is not supported on NPU now.")
962+
self.model = self.load_lora_model(self.model,
963+
self.model_config,
964+
self.scheduler_config,
965+
self.lora_config,
966+
self.device)
942967
logger.info("Loading model weights took %.4f GB",
943968
m.consumed_memory / float(2**30))
944969

vllm_ascend/worker/worker_v1.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
set_custom_all_reduce)
3232
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
3333
from vllm.logger import logger
34+
from vllm.lora.request import LoRARequest
3435
from vllm.model_executor import set_random_seed
3536
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
3637
from vllm.v1.core.sched.output import SchedulerOutput
@@ -216,6 +217,18 @@ def profile(self, is_start: bool = True):
216217
else:
217218
self.profiler.stop()
218219

220+
def add_lora(self, lora_request: LoRARequest) -> bool:
221+
return self.model_runner.add_lora(lora_request)
222+
223+
def remove_lora(self, lora_id: int) -> bool:
224+
return self.model_runner.remove_lora(lora_id)
225+
226+
def list_loras(self) -> set[int]:
227+
return self.model_runner.list_loras()
228+
229+
def pin_lora(self, lora_id: int) -> bool:
230+
return self.model_runner.pin_lora(lora_id)
231+
219232
def execute_dummy_batch(self) -> None:
220233
self.model_runner._dummy_run(1)
221234

0 commit comments

Comments
 (0)