
Commit 5f8241c

Potabk and wangli authored
[V1][ModelRunner] Support pooling model for v1 engine (#1359)
### What this PR does / why we need it?

Add support for the v1 pooling task while changing as little existing code as possible. Note that `vllm.v1.worker.gpu_input_batch` is moved down into vllm-ascend: the upstream interface changes frequently, so keeping a local copy decouples us from it.

### How was this patch tested?

CI passed with the newly added and existing tests. A simple local test, adapted from https://www.modelscope.cn/models/Qwen/Qwen3-Embedding-0.6B, was also run first:

```python
import os

import torch
from vllm import LLM

os.environ["VLLM_USE_MODELSCOPE"] = "True"


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery:{query}'


# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
    get_detailed_instruct(task, 'What is the capital of China?'),
    get_detailed_instruct(task, 'Explain gravity')
]
# No need to add instruction for retrieval documents
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
input_texts = queries + documents

model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

outputs = model.embed(input_texts)
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
```

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: wangli <858794774@qq.com>
Co-authored-by: wangli <858794774@qq.com>
1 parent 790c810 commit 5f8241c

File tree

10 files changed, +1312 −43 lines changed


.github/workflows/vllm_ascend_test.yaml

Lines changed: 7 additions & 3 deletions
@@ -86,7 +86,7 @@ jobs:
      - name: Run codespell check
        run: |
          CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**')
-         CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn')
+         CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn')

          codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
      - name: Analysing the code with ruff

@@ -262,11 +262,13 @@ jobs:
          pytest -sv tests/e2e/singlecard/test_ilama_lora.py
          pytest -sv tests/e2e/singlecard/test_guided_decoding.py
          pytest -sv tests/e2e/singlecard/test_camem.py
+         pytest -sv tests/e2e/singlecard/test_embedding.py
          pytest -sv tests/e2e/singlecard/ \
          --ignore=tests/e2e/singlecard/test_offline_inference.py \
          --ignore=tests/e2e/singlecard/test_ilama_lora.py \
          --ignore=tests/e2e/singlecard/test_guided_decoding.py \
-         --ignore=tests/e2e/singlecard/test_camem.py
+         --ignore=tests/e2e/singlecard/test_camem.py \
+         --ignore=tests/e2e/singlecard/test_embedding.py

      - name: Run e2e test on V0 engine
        if: ${{ github.event_name == 'schedule' }}

@@ -281,14 +283,16 @@ jobs:
          pytest -sv tests/e2e/singlecard/test_guided_decoding.py
          pytest -sv tests/e2e/singlecard/test_camem.py
          pytest -sv tests/e2e/singlecard/test_prompt_embedding.py
+         pytest -sv tests/e2e/singlecard/test_embedding.py
          pytest -sv tests/e2e/singlecard/ \
          --ignore=tests/e2e/singlecard/test_offline_inference.py \
          --ignore=tests/e2e/singlecard/test_ilama_lora.py \
          --ignore=tests/e2e/singlecard/test_guided_decoding.py \
          --ignore=tests/e2e/singlecard/test_camem.py \
          --ignore=tests/e2e/singlecard/test_prompt_embedding.py \
          --ignore=tests/e2e/singlecard/core/test_ascend_scheduler.py \
-         --ignore=tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py
+         --ignore=tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py \
+         --ignore=tests/e2e/singlecard/test_embedding.py

  e2e-4-cards:
    needs: [e2e]

examples/offline_embed.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from https://www.modelscope.cn/models/Qwen/Qwen3-Embedding-0.6B
#

import os

import torch
from vllm import LLM

os.environ["VLLM_USE_MODELSCOPE"] = "True"


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery:{query}'


# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
    get_detailed_instruct(task, 'What is the capital of China?'),
    get_detailed_instruct(task, 'Explain gravity')
]
# No need to add instruction for retrieval documents
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
input_texts = queries + documents

model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

outputs = model.embed(input_texts)
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
# Calculate the similarity scores between the first two queries and the last two documents
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
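A note on the example's scoring step: the plain dot product is only meaningful as a similarity score if the vectors returned by the `embed` task are L2-normalized. As a sanity check under that assumption (a minimal sketch, not part of the committed example), an explicit cosine-similarity computation should reproduce the printed scores:

```python
import torch.nn.functional as F

# Explicit cosine similarity between each query (rows) and each document (columns).
# Uses `embeddings` from the example above; if the embeddings are already
# L2-normalized, this matches the dot-product `scores` exactly.
cosine_scores = F.cosine_similarity(embeddings[:2].unsqueeze(1),
                                    embeddings[2:].unsqueeze(0),
                                    dim=-1)
print(cosine_scores.tolist())
```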

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -12,3 +12,4 @@ xgrammar
zmq
types-psutil
pytest-cov
+sentence_transformers
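The new `sentence_transformers` dev dependency backs the Hugging Face baseline added to `tests/conftest.py` below (it is imported lazily inside `HfRunner`). A minimal standalone sketch of what that baseline computes, assuming the same model used by the new e2e test, would be:

```python
from sentence_transformers import SentenceTransformer

# HF-side reference embeddings, mirroring what HfRunner.encode() returns
# when constructed with is_sentence_transformer=True.
model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
hf_embeddings = model.encode(["What is the capital of China?", "Explain gravity"])
print(hf_embeddings.shape)  # (2, hidden_size)
```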

tests/conftest.py

Lines changed: 136 additions & 2 deletions
@@ -19,18 +19,23 @@

 import contextlib
 import gc
-from typing import List, Optional, Tuple, TypeVar, Union
+from typing import Any, List, Optional, Tuple, TypeVar, Union

 import numpy as np
 import pytest
 import torch
 from huggingface_hub import snapshot_download
 from PIL import Image
+from torch import nn
+from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
+                          BatchEncoding, BatchFeature)
+from transformers.models.auto.auto_factory import _BaseAutoModelClass
 from vllm import LLM, SamplingParams
-from vllm.config import TaskOption
+from vllm.config import TaskOption, _get_and_verify_dtype
 from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
+from vllm.transformers_utils.utils import maybe_model_redirect
 from vllm.utils import is_list_of

 from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,

@@ -45,6 +50,7 @@
 from vllm.distributed.parallel_state import (  # noqa E402
     destroy_distributed_environment, destroy_model_parallel)

+_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
 _M = TypeVar("_M")

 _PromptMultiModalInput = Union[List[_M], List[List[_M]]]

@@ -364,3 +370,131 @@ def prompt_template(request):
 @pytest.fixture(scope="session")
 def ilama_lora_files():
     return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
+
+
+class HfRunner:
+
+    def get_default_device(self):
+        from vllm.platforms import current_platform
+
+        return ("cpu"
+                if current_platform.is_cpu() else current_platform.device_type)
+
+    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
+        if x is None or isinstance(x, (bool, )):
+            return x
+
+        if device is None:
+            device = self.device
+
+        if isinstance(x, dict):
+            return {k: self.wrap_device(v, device) for k, v in x.items()}
+
+        if hasattr(x, "device") and x.device.type == device:
+            return x
+
+        return x.to(device)
+
+    def __init__(
+        self,
+        model_name: str,
+        dtype: str = "auto",
+        *,
+        model_kwargs: Optional[dict[str, Any]] = None,
+        trust_remote_code: bool = True,
+        is_sentence_transformer: bool = False,
+        is_cross_encoder: bool = False,
+        skip_tokenizer_init: bool = False,
+        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
+    ) -> None:
+        model_name = maybe_model_redirect(model_name)
+        self.model_name = model_name
+
+        self.config = AutoConfig.from_pretrained(
+            model_name,
+            trust_remote_code=trust_remote_code,
+        )
+        self.device = self.get_default_device()
+        self.dtype = torch_dtype = _get_and_verify_dtype(
+            self.model_name,
+            self.config,
+            dtype=dtype,
+            is_pooling_model=is_sentence_transformer or is_cross_encoder,
+        )
+
+        model_kwargs = model_kwargs if model_kwargs is not None else {}
+        model_kwargs.setdefault("torch_dtype", torch_dtype)
+
+        if is_sentence_transformer:
+            # Lazy init required for AMD CI
+            from sentence_transformers import SentenceTransformer
+
+            self.model = SentenceTransformer(
+                model_name,
+                device=self.device,
+                model_kwargs=model_kwargs,
+                trust_remote_code=trust_remote_code,
+            )
+        elif is_cross_encoder:
+            # Lazy init required for AMD CI
+            from sentence_transformers import CrossEncoder
+
+            self.model = CrossEncoder(
+                model_name,
+                device=self.device,
+                automodel_args=model_kwargs,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            model = auto_cls.from_pretrained(
+                model_name,
+                trust_remote_code=trust_remote_code,
+                **model_kwargs,
+            )
+
+            # in case some unquantized custom models are not in same dtype
+            if (getattr(model, "quantization_method", None) is None
+                    and any(p.dtype != self.dtype
+                            for p in model.parameters())):
+                model = model.to(dtype=self.dtype)
+
+            if (getattr(model, "quantization_method", None) != "bitsandbytes"
+                    and len({p.device
+                             for p in model.parameters()}) < 2):
+                model = model.to(device=self.device)
+
+            self.model = model
+
+        if not skip_tokenizer_init:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
+        # don't put this import at the top level
+        # it will call torch.cuda.device_count()
+        from transformers import AutoProcessor  # noqa: F401
+        self.processor = AutoProcessor.from_pretrained(
+            model_name,
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+        )
+        if skip_tokenizer_init:
+            self.tokenizer = self.processor.tokenizer
+
+    def encode(self, prompts: list[str], *args,
+               **kwargs) -> list[list[torch.Tensor]]:
+        return self.model.encode(prompts, *args, **kwargs)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        del self.model
+        cleanup_dist_env_and_memory()
+
+
+@pytest.fixture(scope="session")
+def hf_runner():
+    return HfRunner

tests/e2e/singlecard/test_embedding.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
from collections.abc import Sequence
from typing import Optional

import pytest
from modelscope import snapshot_download  # type: ignore[import-untyped]

from tests.conftest import HfRunner
from tests.utils import check_embeddings_close, matryoshka_fy
from vllm_ascend.utils import vllm_version_is


def run_embedding_correctness_test(
    hf_model: "HfRunner",
    inputs: list[str],
    vllm_outputs: Sequence[list[float]],
    dimensions: Optional[int] = None,
):
    hf_outputs = hf_model.encode(inputs)
    if dimensions:
        hf_outputs = matryoshka_fy(hf_outputs, dimensions)

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )


# dummy to avoid pytest collect nothing and exit code 5
def test_dummy():
    assert True


@pytest.mark.skipif(vllm_version_is("0.9.1"),
                    reason="vLLM 0.9.1 does not support embed task for v1")
def test_embed_models_correctness(hf_runner, vllm_runner):
    queries = ['What is the capital of China?', 'Explain gravity']

    model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B")
    with vllm_runner(
            model_name,
            task="embed",
            enforce_eager=True,
    ) as vllm_model:
        vllm_outputs = vllm_model.encode(queries)

    with hf_runner(
            model_name,
            dtype="float32",
            is_sentence_transformer=True,
    ) as hf_model:
        run_embedding_correctness_test(hf_model, queries, vllm_outputs)
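`check_embeddings_close` and `matryoshka_fy` are imported from `tests/utils.py`, which is not part of this diff. As a rough, hypothetical sketch (not the actual helper), the comparison amounts to asserting that each HF/vLLM embedding pair has cosine similarity within the given tolerance:

```python
import torch
import torch.nn.functional as F


def check_embeddings_close_sketch(*, embeddings_0_lst, embeddings_1_lst,
                                  name_0, name_1, tol=1e-2):
    # Hypothetical sketch: every pair of embeddings must have cosine
    # similarity within `tol` of 1.0.
    assert len(embeddings_0_lst) == len(embeddings_1_lst)
    for idx, (e0, e1) in enumerate(zip(embeddings_0_lst, embeddings_1_lst)):
        sim = F.cosine_similarity(torch.as_tensor(e0, dtype=torch.float32),
                                  torch.as_tensor(e1, dtype=torch.float32),
                                  dim=0)
        assert sim >= 1 - tol, (
            f"prompt {idx}: {name_0} vs {name_1} cosine similarity "
            f"{sim:.4f} is below {1 - tol}")
```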
