From f50dd8754cb7f3be9766b0964494f69d1dab0dd8 Mon Sep 17 00:00:00 2001 From: jesse Date: Fri, 9 May 2025 16:59:44 +0800 Subject: [PATCH 1/9] [Feature] add V1 Engine LoRA support (vllm-project#801) What this PR does / why we need it? According to this RFC #396 and this #448, we pull request relavant code to support LoRA in v1 Engine Does this PR introduce any user-facing change? Following openai HTTP apis will be supported: /v1/load_lora_adapter /v1/unload_lora_adapter How was this patch tested? git clone https://github.com/vllm-project/vllm.git cd vllm/examples/offline_inference/ && python3 multilora_inference.py Signed-off-by: jesse --- .github/workflows/vllm_ascend_test.yaml | 12 +-- tests/singlecard/test_transfomers_model.py | 67 ++++++++++++++++ vllm_ascend/worker/model_runner_v1.py | 91 ++++++++++++++-------- vllm_ascend/worker/worker_v1.py | 13 ++++ 4 files changed, 145 insertions(+), 38 deletions(-) create mode 100644 tests/singlecard/test_transfomers_model.py diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 7f8ac0e31..a8b93de1c 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -51,11 +51,11 @@ jobs: vllm_verison: [main, v0.8.5.post1] concurrency: group: > - ${{ - matrix.os == 'linux-arm64-npu-4' - && github.event.pull_request.number - && format('pr-{0}-limit-npu-4', github.event.pull_request.number) - || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number) + ${{ + matrix.os == 'linux-arm64-npu-4' + && github.event.pull_request.number + && format('pr-{0}-limit-npu-4', github.event.pull_request.number) + || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number) }} cancel-in-progress: false name: vLLM Ascend test @@ -112,6 +112,7 @@ jobs: run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then pytest -sv tests/singlecard/test_offline_inference.py + pytest -sv tests/singlecard/test_transfomers_model.py pytest -sv tests/ops pytest -sv tests/compile else @@ -126,6 +127,7 @@ jobs: run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then pytest -sv tests/singlecard/test_offline_inference.py + pytest -sv tests/singlecard/test_transfomers_model.py pytest -sv tests/ops else pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py diff --git a/tests/singlecard/test_transfomers_model.py b/tests/singlecard/test_transfomers_model.py new file mode 100644 index 000000000..814276f77 --- /dev/null +++ b/tests/singlecard/test_transfomers_model.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import vllm +from huggingface_hub import snapshot_download +from vllm.lora.request import LoRARequest + +from tests.conftest import VllmRunner + +MODEL_PATH = "ArthurZ/ilama-3.2-1B" + +PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 + +EXPECTED_LORA_OUTPUT = [ + "SELECT count(*) FROM singer", + "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 + "SELECT DISTINCT Country FROM singer WHERE Age > 20", +] + + +@pytest.fixture(scope="session") +def ilama_lora_files(): + return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider") + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: + prompts = [ + PROMPT_TEMPLATE.format(query="How many singers do we have?"), + PROMPT_TEMPLATE.format( + query= + "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + query= + "What are all distinct countries where singers above age 20 are from?" # noqa: E501 + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts: list[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_ilama_lora(ilama_lora_files): + with VllmRunner(model_name=MODEL_PATH, + enable_lora=True, + max_loras=4, + max_model_len=1024, + max_num_seqs=16) as vllm_model: + + output1 = do_sample(vllm_model.model, ilama_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + + output2 = do_sample(vllm_model.model, ilama_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 18037439c..1bafa5331 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -50,6 +50,7 @@ from vllm.v1.sample.sampler import Sampler from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm_ascend.attention.attention import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState @@ -102,7 +103,7 @@ def graph_capture(device: torch.device): yield graph_capture_context -class NPUModelRunner: +class NPUModelRunner(LoRAModelRunnerMixin): def __init__(self, vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config @@ -507,6 +508,10 @@ def _process_reqs( max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + # Prepare positions req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) @@ -833,39 +838,55 @@ def _profile_multimodal(self) -> None: @torch.inference_mode() def _dummy_run(self, num_tokens: int) -> torch.Tensor: - model = self.model - if self.is_multimodal_model: - input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] - else: - input_ids = self.input_ids[:num_tokens] - inputs_embeds = None + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + with self.maybe_dummy_run_with_lora(self.lora_config, + num_scheduled_tokens): + model = self.model + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None - if self.uses_mrope: - positions = self.mrope_positions[:, :num_tokens] - else: - positions = self.positions[:num_tokens] + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] - if get_pp_group().is_first_rank: - intermediate_tensors = None - else: - if self.intermediate_tensors is None: - self.intermediate_tensors = ( - self.model.make_empty_intermediate_tensors( - batch_size=num_tokens, - dtype=self.dtype, - device=self.device)) - intermediate_tensors = IntermediateTensors({ - k: v[:num_tokens] - for k, v in self.intermediate_tensors.items() - }) - - with set_forward_context(None, self.vllm_config): - hidden_states = model(input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) - return hidden_states + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=num_tokens, + dtype=self.dtype, + device=self.device)) + intermediate_tensors = IntermediateTensors({ + k: v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + + with set_forward_context(None, self.vllm_config): + hidden_states = model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. @@ -914,7 +935,11 @@ def load_model(self) -> None: with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) if self.lora_config: - raise ValueError("LoRA model is not supported on NPU now.") + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) logger.info("Loading model weights took %.4f GB", m.consumed_memory / float(2**30)) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 2ba1973c7..ae6a59eb0 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -31,6 +31,7 @@ set_custom_all_reduce) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.logger import logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput @@ -216,6 +217,18 @@ def profile(self, is_start: bool = True): else: self.profiler.stop() + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def list_loras(self) -> set[int]: + return self.model_runner.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + def execute_dummy_batch(self) -> None: self.model_runner._dummy_run(1) From 907a3221982711701b96d2686eed77b5b7522d3e Mon Sep 17 00:00:00 2001 From: paulyu Date: Mon, 19 May 2025 09:12:17 +0800 Subject: [PATCH 2/9] [TEST] add test_ilama_lora_tp2 Signed-off-by: paulyu --- tests/multicard/test_ilama_lora_tp2.py | 21 +++++++++++++++++++ ...ransfomers_model.py => test_ilama_lora.py} | 0 2 files changed, 21 insertions(+) create mode 100644 tests/multicard/test_ilama_lora_tp2.py rename tests/singlecard/{test_transfomers_model.py => test_ilama_lora.py} (100%) diff --git a/tests/multicard/test_ilama_lora_tp2.py b/tests/multicard/test_ilama_lora_tp2.py new file mode 100644 index 000000000..45b162312 --- /dev/null +++ b/tests/multicard/test_ilama_lora_tp2.py @@ -0,0 +1,21 @@ +from tests.conftest import VllmRunner +from tests.singlecard.test_ilama_lora import MODEL_PATH, do_sample + + +def test_ilama_lora_tp2(ilama_lora_files): + with VllmRunner(model_name=MODEL_PATH, + enable_lora=True, + max_loras=4, + max_model_len=1024, + max_num_seqs=16) as vllm_model: + output1 = do_sample(vllm_model.model, ilama_lora_files, lora_id=1) + + with VllmRunner(model_name=MODEL_PATH, + enable_lora=True, + max_loras=4, + max_model_len=1024, + max_num_seqs=16) as vllm_model: + output2 = do_sample(vllm_model.model, ilama_lora_files, lora_id=2) + + for i in range(len(output1)): + assert output1[i] == output2[i] diff --git a/tests/singlecard/test_transfomers_model.py b/tests/singlecard/test_ilama_lora.py similarity index 100% rename from tests/singlecard/test_transfomers_model.py rename to tests/singlecard/test_ilama_lora.py From 87975204359b9d647d8d80e05308a1e48f15d682 Mon Sep 17 00:00:00 2001 From: paulyu Date: Mon, 19 May 2025 09:25:00 +0800 Subject: [PATCH 3/9] [Bugfix] fix bug Signed-off-by: paulyu --- .github/workflows/vllm_ascend_test.yaml | 5 ++++- tests/multicard/test_ilama_lora_tp2.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index a8b93de1c..761c09562 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -112,11 +112,12 @@ jobs: run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then pytest -sv tests/singlecard/test_offline_inference.py - pytest -sv tests/singlecard/test_transfomers_model.py + pytest -sv tests/singlecard/test_lora_ilama.py pytest -sv tests/ops pytest -sv tests/compile else pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py + pytest -sv tests/singlecard/test_lora_ilama_tp2.py pytest -sv tests/ops pytest -sv tests/compile fi @@ -126,10 +127,12 @@ jobs: VLLM_USE_V1: 0 run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then + pytest -sv tests/singlecard/test_lora_ilama.py pytest -sv tests/singlecard/test_offline_inference.py pytest -sv tests/singlecard/test_transfomers_model.py pytest -sv tests/ops else + pytest -sv tests/singlecard/test_lora_ilama_tp2.py pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py pytest -sv -k "DeepSeek" tests/multicard/test_offline_inference_distributed.py pytest -sv tests/ops diff --git a/tests/multicard/test_ilama_lora_tp2.py b/tests/multicard/test_ilama_lora_tp2.py index 45b162312..e2951d64b 100644 --- a/tests/multicard/test_ilama_lora_tp2.py +++ b/tests/multicard/test_ilama_lora_tp2.py @@ -14,7 +14,8 @@ def test_ilama_lora_tp2(ilama_lora_files): enable_lora=True, max_loras=4, max_model_len=1024, - max_num_seqs=16) as vllm_model: + max_num_seqs=16, + tensor_parallel_size=2) as vllm_model: output2 = do_sample(vllm_model.model, ilama_lora_files, lora_id=2) for i in range(len(output1)): From 9eb8dd7344ca7c07f167e83223f09e953412627e Mon Sep 17 00:00:00 2001 From: paulyu Date: Mon, 19 May 2025 09:26:58 +0800 Subject: [PATCH 4/9] [Bugfix] fix bug Signed-off-by: paulyu --- .github/workflows/vllm_ascend_test.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 761c09562..3237c77c6 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -129,7 +129,6 @@ jobs: if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then pytest -sv tests/singlecard/test_lora_ilama.py pytest -sv tests/singlecard/test_offline_inference.py - pytest -sv tests/singlecard/test_transfomers_model.py pytest -sv tests/ops else pytest -sv tests/singlecard/test_lora_ilama_tp2.py From 59d920ce0664dfe2e490c081caf115bb7da95a20 Mon Sep 17 00:00:00 2001 From: paulyu Date: Mon, 19 May 2025 09:43:29 +0800 Subject: [PATCH 5/9] [Bugfix] fix bug Signed-off-by: paulyu --- .github/workflows/vllm_ascend_test.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 3237c77c6..eeef8f907 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -112,12 +112,12 @@ jobs: run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then pytest -sv tests/singlecard/test_offline_inference.py - pytest -sv tests/singlecard/test_lora_ilama.py + pytest -sv tests/singlecard/test_ilama_lora.py pytest -sv tests/ops pytest -sv tests/compile else pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py - pytest -sv tests/singlecard/test_lora_ilama_tp2.py + pytest -sv tests/singlecard/test_ilama_lora_tp2.py pytest -sv tests/ops pytest -sv tests/compile fi @@ -127,11 +127,11 @@ jobs: VLLM_USE_V1: 0 run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then - pytest -sv tests/singlecard/test_lora_ilama.py + pytest -sv tests/singlecard/test_ilama_lora.py pytest -sv tests/singlecard/test_offline_inference.py pytest -sv tests/ops else - pytest -sv tests/singlecard/test_lora_ilama_tp2.py + pytest -sv tests/singlecard/test_ilama_lora_tp2.py pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py pytest -sv -k "DeepSeek" tests/multicard/test_offline_inference_distributed.py pytest -sv tests/ops From 9625adb327d936f50fd2ac423ae721139eeae2d1 Mon Sep 17 00:00:00 2001 From: paulyu Date: Mon, 19 May 2025 14:57:59 +0800 Subject: [PATCH 6/9] [Bugfix] fix bug Signed-off-by: paulyu --- .github/workflows/vllm_ascend_test.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index eeef8f907..ba9f02674 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -117,7 +117,7 @@ jobs: pytest -sv tests/compile else pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py - pytest -sv tests/singlecard/test_ilama_lora_tp2.py + pytest -sv tests/multicard/test_ilama_lora_tp2.py pytest -sv tests/ops pytest -sv tests/compile fi @@ -131,7 +131,7 @@ jobs: pytest -sv tests/singlecard/test_offline_inference.py pytest -sv tests/ops else - pytest -sv tests/singlecard/test_ilama_lora_tp2.py + pytest -sv tests/multicard/test_ilama_lora_tp2.py pytest -sv -k "QwQ" tests/multicard/test_offline_inference_distributed.py pytest -sv -k "DeepSeek" tests/multicard/test_offline_inference_distributed.py pytest -sv tests/ops From 69650749240e9240d2b8b73cdffb8ec748e95e34 Mon Sep 17 00:00:00 2001 From: paulyu Date: Mon, 19 May 2025 16:50:51 +0800 Subject: [PATCH 7/9] [Bugfix] fix bug Signed-off-by: paulyu --- tests/conftest.py | 8 +++++++- tests/singlecard/test_ilama_lora.py | 7 ------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 78ffe8f4e..c9a62cb3b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ import numpy as np import pytest import torch +from huggingface_hub import snapshot_download from PIL import Image from vllm import LLM, SamplingParams from vllm.config import TaskOption @@ -348,4 +349,9 @@ def vllm_runner(): @pytest.fixture(params=list(PROMPT_TEMPLATES.keys())) def prompt_template(request): - return PROMPT_TEMPLATES[request.param] \ No newline at end of file + return PROMPT_TEMPLATES[request.param] + + +@pytest.fixture(scope="session") +def ilama_lora_files(): + return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider") \ No newline at end of file diff --git a/tests/singlecard/test_ilama_lora.py b/tests/singlecard/test_ilama_lora.py index 814276f77..2d93bceea 100644 --- a/tests/singlecard/test_ilama_lora.py +++ b/tests/singlecard/test_ilama_lora.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -import pytest import vllm -from huggingface_hub import snapshot_download from vllm.lora.request import LoRARequest from tests.conftest import VllmRunner @@ -18,11 +16,6 @@ ] -@pytest.fixture(scope="session") -def ilama_lora_files(): - return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider") - - def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), From 1dd4084a9aaae070b322d2b18d7c4c19938dbba3 Mon Sep 17 00:00:00 2001 From: paulyu Date: Mon, 19 May 2025 20:26:38 +0800 Subject: [PATCH 8/9] [Bugfix] fix bug Signed-off-by: paulyu --- tests/multicard/test_ilama_lora_tp2.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/multicard/test_ilama_lora_tp2.py b/tests/multicard/test_ilama_lora_tp2.py index e2951d64b..2b1a48488 100644 --- a/tests/multicard/test_ilama_lora_tp2.py +++ b/tests/multicard/test_ilama_lora_tp2.py @@ -1,8 +1,11 @@ +import pytest + from tests.conftest import VllmRunner from tests.singlecard.test_ilama_lora import MODEL_PATH, do_sample -def test_ilama_lora_tp2(ilama_lora_files): +@pytest.mark.parametrize("distributed_executor_backend", ["mp"]) +def test_ilama_lora_tp2(distributed_executor_backend, ilama_lora_files): with VllmRunner(model_name=MODEL_PATH, enable_lora=True, max_loras=4, @@ -15,7 +18,9 @@ def test_ilama_lora_tp2(ilama_lora_files): max_loras=4, max_model_len=1024, max_num_seqs=16, - tensor_parallel_size=2) as vllm_model: + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend + ) as vllm_model: output2 = do_sample(vllm_model.model, ilama_lora_files, lora_id=2) for i in range(len(output1)): From 7125ff1f02cde4762a7fe0834ddd426c9b67354d Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Mon, 19 May 2025 23:13:12 +0800 Subject: [PATCH 9/9] [Bugfix] fig bug Signed-off-by: paulyu12 <507435917@qq.com> --- tests/multicard/test_ilama_lora_tp2.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tests/multicard/test_ilama_lora_tp2.py b/tests/multicard/test_ilama_lora_tp2.py index 2b1a48488..e61ce250c 100644 --- a/tests/multicard/test_ilama_lora_tp2.py +++ b/tests/multicard/test_ilama_lora_tp2.py @@ -1,18 +1,12 @@ import pytest from tests.conftest import VllmRunner -from tests.singlecard.test_ilama_lora import MODEL_PATH, do_sample +from tests.singlecard.test_ilama_lora import (EXPECTED_LORA_OUTPUT, MODEL_PATH, + do_sample) @pytest.mark.parametrize("distributed_executor_backend", ["mp"]) def test_ilama_lora_tp2(distributed_executor_backend, ilama_lora_files): - with VllmRunner(model_name=MODEL_PATH, - enable_lora=True, - max_loras=4, - max_model_len=1024, - max_num_seqs=16) as vllm_model: - output1 = do_sample(vllm_model.model, ilama_lora_files, lora_id=1) - with VllmRunner(model_name=MODEL_PATH, enable_lora=True, max_loras=4, @@ -21,7 +15,7 @@ def test_ilama_lora_tp2(distributed_executor_backend, ilama_lora_files): tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend ) as vllm_model: - output2 = do_sample(vllm_model.model, ilama_lora_files, lora_id=2) + output = do_sample(vllm_model.model, ilama_lora_files, lora_id=2) - for i in range(len(output1)): - assert output1[i] == output2[i] + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output[i] == EXPECTED_LORA_OUTPUT[i]