Commit bd70ce8

[CI] Add qwen2.5-vl test (#643)
### What this PR does / why we need it?

Part of #499. Adds a qwen2.5-vl test on a single NPU. The v1 engine is excluded because qwen2.5-vl currently has problems with v1. This test also makes #639 more credible.

Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent a9c6b52 commit bd70ce8

File tree

3 files changed: +48 -2 lines

tests/conftest.py
tests/model_utils.py
tests/singlecard/test_offline_inference.py

tests/conftest.py

Lines changed: 6 additions & 1 deletion

@@ -31,7 +31,7 @@
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import is_list_of

-from tests.model_utils import (TokensTextLogprobs,
+from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
 # TODO: remove this part after the patch merged into vllm, if
 # we not explicitly patch here, some of them might be effectiveless
@@ -344,3 +344,8 @@ def __exit__(self, exc_type, exc_value, traceback):
 @pytest.fixture(scope="session")
 def vllm_runner():
     return VllmRunner
+
+
+@pytest.fixture(params=list(PROMPT_TEMPLATES.keys()))
+def prompt_template(request):
+    return PROMPT_TEMPLATES[request.param]
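For context on the fixture mechanics: a pytest fixture declared with params runs every test that requests it once per parameter, and request.param carries the current value. A minimal self-contained sketch of the pattern (the TEMPLATES dict and test name here are hypothetical stand-ins, not part of this PR):

import pytest

# Hypothetical stand-in for PROMPT_TEMPLATES: template name -> callable.
TEMPLATES = {"upper": str.upper, "lower": str.lower}


@pytest.fixture(params=list(TEMPLATES.keys()))
def template(request):
    # request.param is the current key ("upper" or "lower").
    return TEMPLATES[request.param]


def test_render(template):
    # pytest collects this test twice, once per template in TEMPLATES.
    assert template("Hi") in ("HI", "hi")

This is why test_multimodal below can simply take prompt_template as an argument: pytest generates one test variant per key in PROMPT_TEMPLATES.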

tests/model_utils.py

Lines changed: 14 additions & 1 deletion

@@ -18,7 +18,7 @@
 #

 import warnings
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

 import torch
 from vllm.config import ModelConfig, TaskOption
@@ -301,3 +301,16 @@ def build_model_context(model_name: str,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
     return InputContext(model_config)
+
+
+def qwen_prompt(questions: List[str]) -> List[str]:
+    placeholder = "<|image_pad|>"
+    return [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+             f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
+             f"{q}<|im_end|>\n<|im_start|>assistant\n") for q in questions]
+
+
+# Map of prompt templates for different models.
+PROMPT_TEMPLATES: dict[str, Callable] = {
+    "qwen2.5vl": qwen_prompt,
+}
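To make the template concrete, this is the string qwen_prompt renders for a single question, following directly from the function above (output shown in the comments):

prompts = qwen_prompt(["What's in the image?"])
print(prompts[0])
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <|vision_start|><|image_pad|><|vision_end|>What's in the image?<|im_end|>
# <|im_start|>assistant

The <|vision_start|>...<|vision_end|> span marks where the processed image features are spliced in, with <|image_pad|> serving as the image placeholder token in the Qwen2-VL-family chat format.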

tests/singlecard/test_offline_inference.py

Lines changed: 28 additions & 0 deletions

@@ -24,6 +24,7 @@

 import pytest
 import vllm  # noqa: F401
+from vllm.assets.image import ImageAsset

 import vllm_ascend  # noqa: F401
 from tests.conftest import VllmRunner
@@ -32,6 +33,7 @@
     "Qwen/Qwen2.5-0.5B-Instruct",
     "vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8",
 ]
+MULTIMODALITY_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]
 os.environ["VLLM_USE_MODELSCOPE"] = "True"
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"

@@ -55,6 +57,32 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
         vllm_model.generate_greedy(example_prompts, max_tokens)


+@pytest.mark.parametrize("model", MULTIMODALITY_MODELS)
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
+                    reason="qwen2.5_vl is not supported on v1")
+def test_multimodal(model, prompt_template, vllm_runner):
+    image = ImageAsset("cherry_blossom") \
+        .pil_image.convert("RGB")
+    img_questions = [
+        "What is the content of this image?",
+        "Describe the content of this image in detail.",
+        "What's in the image?",
+        "Where is this image taken?",
+    ]
+    images = [image] * len(img_questions)
+    prompts = prompt_template(img_questions)
+    with vllm_runner(model,
+                     max_model_len=4096,
+                     mm_processor_kwargs={
+                         "min_pixels": 28 * 28,
+                         "max_pixels": 1280 * 28 * 28,
+                         "fps": 1,
+                     }) as vllm_model:
+        vllm_model.generate_greedy(prompts=prompts,
+                                   images=images,
+                                   max_tokens=64)
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
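Outside the test harness, the same call pattern maps onto vLLM's offline API roughly as below. This is a hedged sketch, not the repo's VllmRunner implementation: the LLM constructor's mm_processor_kwargs and the dict-with-multi_modal_data prompt form exist in vLLM, but exact behavior may vary across versions, and the prompt string is the qwen_prompt template from tests/model_utils.py inlined by hand.

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
          "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
          "What's in the image?<|im_end|>\n<|im_start|>assistant\n")

llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          max_model_len=4096,
          mm_processor_kwargs={"min_pixels": 28 * 28,
                               "max_pixels": 1280 * 28 * 28})
# Greedy decoding: temperature 0, mirroring generate_greedy in the test.
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate([{"prompt": prompt,
                         "multi_modal_data": {"image": image}}],
                       params)
print(outputs[0].outputs[0].text)

To run just the new test locally, something like VLLM_USE_V1=0 pytest tests/singlecard/test_offline_inference.py -k test_multimodal should work: the skipif guard keeps it off the v1 engine, matching the exclusion described in the commit message.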
