Commit ab8953b

[Feature] add V1 Engine LoRA support (vllm-project#801)
What this PR does / why we need it?
According to RFC vllm-project#396 and vllm-project#448, this PR adds the relevant code to support LoRA in the V1 Engine.

Does this PR introduce any user-facing change?
The following OpenAI-compatible HTTP APIs will be supported:
/v1/load_lora_adapter
/v1/unload_lora_adapter

How was this patch tested?
git clone https://github.com/vllm-project/vllm.git
cd vllm/examples/offline_inference/ && python3 multilora_inference.py

Signed-off-by: jesse <szxfml@gmail.com>
1 parent cdece86 commit ab8953b
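
For context on the two endpoints named above, here is a minimal sketch of how they are typically exercised against a running OpenAI-compatible server. The server flag, the VLLM_ALLOW_RUNTIME_LORA_UPDATING environment variable, and the adapter name/path are illustrative assumptions, not part of this diff; the JSON field names follow vLLM's documented dynamic-adapter API at the time of writing.

# Hedged sketch: dynamic LoRA load/unload over HTTP. Assumes the server was
# started with --enable-lora and with runtime adapter updates allowed
# (e.g. VLLM_ALLOW_RUNTIME_LORA_UPDATING=True); name and path are placeholders.
import requests

BASE_URL = "http://localhost:8000"

# Register a new adapter with the running server.
resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                     json={
                         "lora_name": "sql_lora",
                         "lora_path": "/path/to/sql_lora_adapter",
                     })
resp.raise_for_status()

# ... send /v1/completions requests with "model": "sql_lora" ...

# Unregister the adapter when it is no longer needed.
resp = requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
                     json={"lora_name": "sql_lora"})
resp.raise_for_status()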

File tree

5 files changed: +315 −38 lines

.github/workflows/vllm_ascend_test.yaml

Lines changed: 7 additions & 5 deletions
@@ -51,11 +51,11 @@ jobs:
         vllm_verison: [main, v0.8.5.post1]
     concurrency:
       group: >
-        ${{
-        matrix.os == 'linux-arm64-npu-4'
-        && github.event.pull_request.number
-        && format('pr-{0}-limit-npu-4', github.event.pull_request.number)
-        || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
+        ${{
+        matrix.os == 'linux-arm64-npu-4'
+        && github.event.pull_request.number
+        && format('pr-{0}-limit-npu-4', github.event.pull_request.number)
+        || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
         }}
       cancel-in-progress: false
     name: vLLM Ascend test
@@ -112,6 +112,7 @@ jobs:
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
             pytest -sv tests/singlecard/test_offline_inference.py
+            pytest -sv tests/singlecard/test_lora_functions.py
            pytest -sv tests/ops
            pytest -sv tests/compile
          else
@@ -126,6 +127,7 @@ jobs:
        run: |
          if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
            pytest -sv tests/singlecard/test_offline_inference.py
+            pytest -sv tests/singlecard/test_lora_functions.py
            pytest -sv tests/ops
          else
            pytest -sv tests/multicard/test_offline_inference_distributed.py

tests/singlecard/test_baichuan.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "baichuan-inc/Baichuan-7B"
+
+PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501
+
+
+def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
+    prompts = [
+        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
+        PROMPT_TEMPLATE.format(
+            query=
+            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
+        ),
+        PROMPT_TEMPLATE.format(
+            query=
+            "Show name, country, age for all singers ordered by age from the oldest to the youngest."  # noqa: E501
+        ),
+    ]
+    print(prompts)
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts: list[str] = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+def test_baichuan_lora(baichuan_lora_files):
+    llm = vllm.LLM(MODEL_PATH,
+                   max_model_len=1024,
+                   enable_lora=True,
+                   max_loras=4,
+                   max_lora_rank=64,
+                   trust_remote_code=True)
+
+    expected_lora_output = [
+        "SELECT count(*) FROM singer",
+        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'",  # noqa: E501
+        "SELECT name , country , age FROM singer ORDER BY age ASC",
+    ]
+
+    output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
+    for i in range(len(expected_lora_output)):
+        assert output1[i] == expected_lora_output[i]
+    output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
+    for i in range(len(expected_lora_output)):
+        assert output2[i] == expected_lora_output[i]
+
+
+@pytest.mark.parametrize("fully_sharded", [True, False])
+def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
+                                           num_gpus_available, fully_sharded):
+    if num_gpus_available < 4:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
+
+    llm_tp1 = vllm.LLM(MODEL_PATH,
+                       enable_lora=True,
+                       max_num_seqs=16,
+                       max_loras=4,
+                       max_lora_rank=64,
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
+    output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
+
+    del llm_tp1
+    cleanup_dist_env_and_memory()
+
+    llm_tp2 = vllm.LLM(MODEL_PATH,
+                       enable_lora=True,
+                       max_num_seqs=16,
+                       max_loras=4,
+                       max_lora_rank=64,
+                       tensor_parallel_size=2,
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
+    output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
+
+    del llm_tp2
+    cleanup_dist_env_and_memory()
+
+    assert output_tp1 == output_tp2
+
+    llm_tp4 = vllm.LLM(MODEL_PATH,
+                       enable_lora=True,
+                       max_num_seqs=16,
+                       max_loras=4,
+                       max_lora_rank=64,
+                       tensor_parallel_size=4,
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
+    output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
+
+    del llm_tp4
+    cleanup_dist_env_and_memory()
+
+    assert output_tp1 == output_tp4
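
The offline path listed under "How was this patch tested?" (examples/offline_inference/multilora_inference.py) follows the same pattern as do_sample above. A condensed sketch, reusing the base model and adapter identifiers that appear in the tests of this commit; it is an illustration, not the example script itself:

import vllm
from vllm.lora.request import LoRARequest

# Base model and adapter are the ones referenced by the tests in this commit;
# max_loras / max_lora_rank values are illustrative.
llm = vllm.LLM("NousResearch/Llama-2-7b-hf",
               enable_lora=True,
               max_loras=4,
               max_lora_rank=8)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)

# Each generate() call can carry its own adapter: LoRARequest(name, int_id, path).
outputs = llm.generate(
    ["Write a SQL query that counts all singers."],
    sampling_params,
    lora_request=LoRARequest("sql_lora", 1, "yard1/llama-2-7b-sql-lora-test"),
)
print(outputs[0].outputs[0].text)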

tests/singlecard/test_lora_functions.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Script to test add_lora, remove_lora, pin_lora, list_loras functions.
+"""
+
+import os
+
+import pytest
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "NousResearch/Llama-2-7b-hf"
+LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test"
+LORA_RANK = 8
+
+
+def make_lora_request(lora_id: int):
+    return LoRARequest(lora_name=f"{lora_id}",
+                       lora_int_id=lora_id,
+                       lora_path=LORA_MODULE_PATH)
+
+
+def test_lora_functions_sync():
+
+    max_loras = 4
+    # Create the engine in eager mode. Due to the high max_loras, the CI can
+    # OOM during cuda-graph capture.
+    engine_args = EngineArgs(model=MODEL_PATH,
+                             enable_lora=True,
+                             max_loras=max_loras,
+                             max_lora_rank=LORA_RANK,
+                             max_model_len=128,
+                             gpu_memory_utilization=0.8,
+                             enforce_eager=True)
+
+    llm = LLMEngine.from_engine_args(engine_args)
+
+    def run_check(fn, args, expected: list):
+        fn(args)
+        assert set(llm.list_loras()) == set(expected)
+
+    run_check(llm.add_lora, make_lora_request(1), [1])
+    run_check(llm.add_lora, make_lora_request(2), [1, 2])
+
+    # Pin LoRA 1 and test that it is never removed on subsequent adds.
+    run_check(llm.pin_lora, 1, [1, 2])
+    run_check(llm.add_lora, make_lora_request(3), [1, 2, 3])
+    run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4])
+    run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4])
+    run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4])
+    run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7])
+    run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7])
+    run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7])
+    run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10])
+
+    # Remove LoRA 1 and continue adding.
+    run_check(llm.remove_lora, 1, [8, 9, 10])
+    run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11])
+    run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
+    run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])
+
+    # Remove all LoRAs.
+    run_check(llm.remove_lora, 13, [12, 10, 11])
+    run_check(llm.remove_lora, 12, [10, 11])
+    run_check(llm.remove_lora, 11, [10])
+    run_check(llm.remove_lora, 10, [])
+
+
+@pytest.mark.asyncio
+async def test_lora_functions_async():
+
+    if os.getenv("VLLM_USE_V1") == "0":
+        pytest.skip(
+            reason=
+            "V0 AsyncLLMEngine does not expose remove/list/pin LoRA functions")
+
+    # The run_with_both_engines_lora fixture sets the `VLLM_USE_V1`
+    # environment variable. Reload vllm.engine.async_llm_engine because
+    # vllm.engine.async_llm_engine.AsyncLLMEngine changes depending on the
+    # env var.
+    import importlib
+
+    import vllm.engine.async_llm_engine
+    importlib.reload(vllm.engine.async_llm_engine)
+    from vllm.entrypoints.openai.api_server import \
+        build_async_engine_client_from_engine_args
+
+    max_loras = 4
+    engine_args = AsyncEngineArgs(model=MODEL_PATH,
+                                  enable_lora=True,
+                                  max_loras=max_loras,
+                                  max_lora_rank=LORA_RANK,
+                                  max_model_len=128,
+                                  gpu_memory_utilization=0.8,
+                                  enforce_eager=True)
+
+    async def run_check(fn, args, expected: list):
+        await fn(args)
+        assert set(await llm.list_loras()) == set(expected)
+
+    async with build_async_engine_client_from_engine_args(engine_args) as llm:
+        await run_check(llm.add_lora, make_lora_request(1), [1])
+        await run_check(llm.add_lora, make_lora_request(2), [1, 2])
+
+        # Pin LoRA 1 and test that it is never removed on subsequent adds.
+        await run_check(llm.pin_lora, 1, [1, 2])
+        await run_check(llm.add_lora, make_lora_request(3), [1, 2, 3])
+        await run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4])
+        await run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4])
+        await run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4])
+        await run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7])
+        await run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7])
+        await run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7])
+        await run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10])
+
+        # Remove LoRA 1 and continue adding.
+        await run_check(llm.remove_lora, 1, [8, 9, 10])
+        await run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11])
+        await run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
+        await run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])
+
+        # Remove all LoRAs.
+        await run_check(llm.remove_lora, 13, [12, 10, 11])
+        await run_check(llm.remove_lora, 12, [10, 11])
+        await run_check(llm.remove_lora, 11, [10])
+        await run_check(llm.remove_lora, 10, [])
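
The expected id lists in the test above encode an LRU-with-pinning policy over max_loras adapter slots: adapter 1 is pinned and never evicted, and each new add replaces the least recently used unpinned adapter in place. A toy model of that behaviour follows (an illustration only, not vLLM's actual LoRA manager); since the test never runs inference between calls, "least recently used" reduces to "least recently added":

class ToyLoRASlots:
    """Capacity-limited adapter table with pinning (illustration only)."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.slots: list[int] = []   # adapter ids in slot order
        self.lru: list[int] = []     # adapter ids, oldest first
        self.pinned: set[int] = set()

    def add(self, lora_id: int) -> None:
        if lora_id in self.slots:
            return
        if len(self.slots) < self.capacity:
            self.slots.append(lora_id)
        else:
            # Evict the oldest unpinned adapter and reuse its slot.
            victim = next(i for i in self.lru if i not in self.pinned)
            self.slots[self.slots.index(victim)] = lora_id
            self.lru.remove(victim)
        self.lru.append(lora_id)

    def pin(self, lora_id: int) -> None:
        self.pinned.add(lora_id)


cache = ToyLoRASlots(capacity=4)
cache.add(1)
cache.add(2)
cache.pin(1)
for i in (3, 4, 5):
    cache.add(i)
print(cache.slots)  # [1, 5, 3, 4] -- matches the expectation after add_lora(5)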

vllm_ascend/worker/model_runner_v1.py

Lines changed: 58 additions & 33 deletions
@@ -51,6 +51,7 @@
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -101,7 +102,7 @@ def graph_capture(device: torch.device):
     yield graph_capture_context


-class NPUModelRunner:
+class NPUModelRunner(LoRAModelRunnerMixin):

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
@@ -511,6 +512,10 @@ def _process_reqs(
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)

+        # Hot-Swap lora model
+        if self.lora_config:
+            self.set_active_loras(self.input_batch, num_scheduled_tokens)
+
         # Prepare positions
         req_indices = np.repeat(self.arange_np[:num_reqs],
                                 num_scheduled_tokens)
@@ -794,39 +799,55 @@ def _profile_multimodal(self) -> None:

     @torch.inference_mode()
     def _dummy_run(self, num_tokens: int) -> torch.Tensor:
-        model = self.model
-        if self.is_multimodal_model:
-            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
-        else:
-            input_ids = self.input_ids[:num_tokens]
-            inputs_embeds = None
+        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
+        # for dummy run with LoRA so that the num_reqs collectively
+        # has num_tokens in total.
+        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        min_tokens_per_req = num_tokens // num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                        dtype=np.int32)
+        with self.maybe_dummy_run_with_lora(self.lora_config,
+                                            num_scheduled_tokens):
+            model = self.model
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None

-        if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
-        else:
-            positions = self.positions[:num_tokens]
+            if self.uses_mrope:
+                positions = self.mrope_positions[:, :num_tokens]
+            else:
+                positions = self.positions[:num_tokens]

-        if get_pp_group().is_first_rank:
-            intermediate_tensors = None
-        else:
-            if self.intermediate_tensors is None:
-                self.intermediate_tensors = (
-                    self.model.make_empty_intermediate_tensors(
-                        batch_size=num_tokens,
-                        dtype=self.dtype,
-                        device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
-
-        with set_forward_context(None, self.vllm_config):
-            hidden_states = model(input_ids=input_ids,
-                                  positions=positions,
-                                  intermediate_tensors=intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
-        return hidden_states
+            if get_pp_group().is_first_rank:
+                intermediate_tensors = None
+            else:
+                if self.intermediate_tensors is None:
+                    self.intermediate_tensors = (
+                        self.model.make_empty_intermediate_tensors(
+                            batch_size=self.max_num_tokens,
+                            dtype=self.model_config.dtype,
+                            device=self.device))
+                intermediate_tensors = IntermediateTensors({
+                    k: v[:num_tokens]
+                    for k, v in self.intermediate_tensors.items()
+                })
+
+            with set_forward_context(None, self.vllm_config):
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions.to(self.device),
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
+            return hidden_states

     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
@@ -875,7 +896,11 @@ def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
             if self.lora_config:
-                raise ValueError("LoRA model is not supported on NPU now.")
+                self.model = self.load_lora_model(self.model,
+                                                  self.model_config,
+                                                  self.scheduler_config,
+                                                  self.lora_config,
+                                                  self.device)
         logger.info("Loading model weights took %.4f GB",
                     m.consumed_memory / float(2**30))

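The batch split added to _dummy_run spreads num_tokens over at most max_num_seqs dummy requests, giving the remainder to the last request, so that LoRA dummy runs see a per-request token layout rather than one giant request. A standalone sketch of just that arithmetic (illustration only, not the runner code):

import numpy as np


def split_dummy_tokens(num_tokens: int, max_num_seqs: int) -> np.ndarray:
    """Mirror the token split used in the dummy run above."""
    num_reqs = max_num_seqs if num_tokens >= max_num_seqs else num_tokens
    splits = [num_tokens // num_reqs] * num_reqs
    splits[-1] += num_tokens % num_reqs
    assert sum(splits) == num_tokens
    return np.array(splits, dtype=np.int32)


print(split_dummy_tokens(10, 4))  # [2 2 2 4]: the last request absorbs the remainder
print(split_dummy_tokens(3, 4))   # [1 1 1]: fewer tokens than max_num_seqs
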
0 commit comments
