Skip to content

Commit f50dd87

Browse files
committed
[Feature] add V1 Engine LoRA support (vllm-project#801)
What this PR does / why we need it? According to this RFC vllm-project#396 and this vllm-project#448, we pull request relavant code to support LoRA in v1 Engine Does this PR introduce any user-facing change? Following openai HTTP apis will be supported: /v1/load_lora_adapter /v1/unload_lora_adapter How was this patch tested? git clone https://github.com/vllm-project/vllm.git cd vllm/examples/offline_inference/ && python3 multilora_inference.py Signed-off-by: jesse <szxfml@gmail.com>
1 parent 00e0243 commit f50dd87

File tree

4 files changed

+145
-38
lines changed

4 files changed

+145
-38
lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ jobs:
5151
vllm_verison: [main, v0.8.5.post1]
5252
concurrency:
5353
group: >
54-
${{
55-
matrix.os == 'linux-arm64-npu-4'
56-
&& github.event.pull_request.number
57-
&& format('pr-{0}-limit-npu-4', github.event.pull_request.number)
58-
|| format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
54+
${{
55+
matrix.os == 'linux-arm64-npu-4'
56+
&& github.event.pull_request.number
57+
&& format('pr-{0}-limit-npu-4', github.event.pull_request.number)
58+
|| format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
5959
}}
6060
cancel-in-progress: false
6161
name: vLLM Ascend test
@@ -112,6 +112,7 @@ jobs:
112112
run: |
113113
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
114114
pytest -sv tests/singlecard/test_offline_inference.py
115+
pytest -sv tests/singlecard/test_transfomers_model.py
115116
pytest -sv tests/ops
116117
pytest -sv tests/compile
117118
else
@@ -126,6 +127,7 @@ jobs:
126127
run: |
127128
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
128129
pytest -sv tests/singlecard/test_offline_inference.py
130+
pytest -sv tests/singlecard/test_transfomers_model.py
129131
pytest -sv tests/ops
130132
else
131133
pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import pytest
4+
import vllm
5+
from huggingface_hub import snapshot_download
6+
from vllm.lora.request import LoRARequest
7+
8+
from tests.conftest import VllmRunner
9+
10+
MODEL_PATH = "ArthurZ/ilama-3.2-1B"
11+
12+
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
13+
14+
EXPECTED_LORA_OUTPUT = [
15+
"SELECT count(*) FROM singer",
16+
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
17+
"SELECT DISTINCT Country FROM singer WHERE Age > 20",
18+
]
19+
20+
21+
@pytest.fixture(scope="session")
22+
def ilama_lora_files():
23+
return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
24+
25+
26+
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
27+
prompts = [
28+
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
29+
PROMPT_TEMPLATE.format(
30+
query=
31+
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
32+
),
33+
PROMPT_TEMPLATE.format(
34+
query=
35+
"What are all distinct countries where singers above age 20 are from?" # noqa: E501
36+
),
37+
]
38+
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
39+
outputs = llm.generate(
40+
prompts,
41+
sampling_params,
42+
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
43+
if lora_id else None)
44+
# Print the outputs.
45+
generated_texts: list[str] = []
46+
for output in outputs:
47+
prompt = output.prompt
48+
generated_text = output.outputs[0].text.strip()
49+
generated_texts.append(generated_text)
50+
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
51+
return generated_texts
52+
53+
54+
def test_ilama_lora(ilama_lora_files):
55+
with VllmRunner(model_name=MODEL_PATH,
56+
enable_lora=True,
57+
max_loras=4,
58+
max_model_len=1024,
59+
max_num_seqs=16) as vllm_model:
60+
61+
output1 = do_sample(vllm_model.model, ilama_lora_files, lora_id=1)
62+
for i in range(len(EXPECTED_LORA_OUTPUT)):
63+
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
64+
65+
output2 = do_sample(vllm_model.model, ilama_lora_files, lora_id=2)
66+
for i in range(len(EXPECTED_LORA_OUTPUT)):
67+
assert output2[i] == EXPECTED_LORA_OUTPUT[i]

vllm_ascend/worker/model_runner_v1.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from vllm.v1.sample.sampler import Sampler
5151
from vllm.v1.utils import bind_kv_cache
5252
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
53+
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
5354

5455
from vllm_ascend.attention.attention import AttentionMaskBuilder
5556
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -102,7 +103,7 @@ def graph_capture(device: torch.device):
102103
yield graph_capture_context
103104

104105

105-
class NPUModelRunner:
106+
class NPUModelRunner(LoRAModelRunnerMixin):
106107

107108
def __init__(self, vllm_config: VllmConfig, device: torch.device):
108109
self.vllm_config = vllm_config
@@ -507,6 +508,10 @@ def _process_reqs(
507508
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
508509
num_tokens)
509510

511+
# Hot-Swap lora model
512+
if self.lora_config:
513+
self.set_active_loras(self.input_batch, num_scheduled_tokens)
514+
510515
# Prepare positions
511516
req_indices = np.repeat(self.arange_np[:num_reqs],
512517
num_scheduled_tokens)
@@ -833,39 +838,55 @@ def _profile_multimodal(self) -> None:
833838

834839
@torch.inference_mode()
835840
def _dummy_run(self, num_tokens: int) -> torch.Tensor:
836-
model = self.model
837-
if self.is_multimodal_model:
838-
input_ids = None
839-
inputs_embeds = self.inputs_embeds[:num_tokens]
840-
else:
841-
input_ids = self.input_ids[:num_tokens]
842-
inputs_embeds = None
841+
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
842+
# for dummy run with LoRA so that the num_reqs collectively
843+
# has num_tokens in total.
844+
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
845+
max_num_reqs = self.scheduler_config.max_num_seqs
846+
num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
847+
min_tokens_per_req = num_tokens // num_reqs
848+
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
849+
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
850+
assert sum(num_scheduled_tokens_list) == num_tokens
851+
assert len(num_scheduled_tokens_list) == num_reqs
852+
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
853+
dtype=np.int32)
854+
with self.maybe_dummy_run_with_lora(self.lora_config,
855+
num_scheduled_tokens):
856+
model = self.model
857+
if self.is_multimodal_model:
858+
input_ids = None
859+
inputs_embeds = self.inputs_embeds[:num_tokens]
860+
else:
861+
input_ids = self.input_ids[:num_tokens]
862+
inputs_embeds = None
843863

844-
if self.uses_mrope:
845-
positions = self.mrope_positions[:, :num_tokens]
846-
else:
847-
positions = self.positions[:num_tokens]
864+
if self.uses_mrope:
865+
positions = self.mrope_positions[:, :num_tokens]
866+
else:
867+
positions = self.positions[:num_tokens]
848868

849-
if get_pp_group().is_first_rank:
850-
intermediate_tensors = None
851-
else:
852-
if self.intermediate_tensors is None:
853-
self.intermediate_tensors = (
854-
self.model.make_empty_intermediate_tensors(
855-
batch_size=num_tokens,
856-
dtype=self.dtype,
857-
device=self.device))
858-
intermediate_tensors = IntermediateTensors({
859-
k: v[:num_tokens]
860-
for k, v in self.intermediate_tensors.items()
861-
})
862-
863-
with set_forward_context(None, self.vllm_config):
864-
hidden_states = model(input_ids=input_ids,
865-
positions=positions,
866-
intermediate_tensors=intermediate_tensors,
867-
inputs_embeds=inputs_embeds)
868-
return hidden_states
869+
if get_pp_group().is_first_rank:
870+
intermediate_tensors = None
871+
else:
872+
if self.intermediate_tensors is None:
873+
self.intermediate_tensors = (
874+
self.model.make_empty_intermediate_tensors(
875+
batch_size=num_tokens,
876+
dtype=self.dtype,
877+
device=self.device))
878+
intermediate_tensors = IntermediateTensors({
879+
k: v[:num_tokens]
880+
for k, v in self.intermediate_tensors.items()
881+
})
882+
883+
with set_forward_context(None, self.vllm_config):
884+
hidden_states = model(
885+
input_ids=input_ids,
886+
positions=positions,
887+
intermediate_tensors=intermediate_tensors,
888+
inputs_embeds=inputs_embeds)
889+
return hidden_states
869890

870891
def profile_run(self) -> None:
871892
# Profile with multimodal encoder & encoder cache.
@@ -914,7 +935,11 @@ def load_model(self) -> None:
914935
with DeviceMemoryProfiler() as m: # noqa: SIM117
915936
self.model = get_model(vllm_config=self.vllm_config)
916937
if self.lora_config:
917-
raise ValueError("LoRA model is not supported on NPU now.")
938+
self.model = self.load_lora_model(self.model,
939+
self.model_config,
940+
self.scheduler_config,
941+
self.lora_config,
942+
self.device)
918943
logger.info("Loading model weights took %.4f GB",
919944
m.consumed_memory / float(2**30))
920945

vllm_ascend/worker/worker_v1.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
set_custom_all_reduce)
3232
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
3333
from vllm.logger import logger
34+
from vllm.lora.request import LoRARequest
3435
from vllm.model_executor import set_random_seed
3536
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
3637
from vllm.v1.core.sched.output import SchedulerOutput
@@ -216,6 +217,18 @@ def profile(self, is_start: bool = True):
216217
else:
217218
self.profiler.stop()
218219

220+
def add_lora(self, lora_request: LoRARequest) -> bool:
221+
return self.model_runner.add_lora(lora_request)
222+
223+
def remove_lora(self, lora_id: int) -> bool:
224+
return self.model_runner.remove_lora(lora_id)
225+
226+
def list_loras(self) -> set[int]:
227+
return self.model_runner.list_loras()
228+
229+
def pin_lora(self, lora_id: int) -> bool:
230+
return self.model_runner.pin_lora(lora_id)
231+
219232
def execute_dummy_batch(self) -> None:
220233
self.model_runner._dummy_run(1)
221234

0 commit comments

Comments
 (0)