
Commit 8fd0814

maxdebayser authored and 22quinn committed
Support embedding models in V1 (vllm-project#16188)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
1 parent e097335 commit 8fd0814


56 files changed, +895 −287 lines

examples/offline_inference/basic/embed.py

Lines changed: 4 additions & 1 deletion
@@ -12,7 +12,10 @@ def parse_args():
     parser = EngineArgs.add_cli_args(parser)
     # Set example specific arguments
     parser.set_defaults(
-        model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True
+        model="intfloat/e5-mistral-7b-instruct",
+        task="embed",
+        enforce_eager=True,
+        max_model_len=1024,
     )
     return parser.parse_args()
 
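For context, the example's new defaults correspond to an offline embedding run roughly like the sketch below. The LLM.embed call and the printed dimensionality check are illustrative additions, not part of this diff.

# Sketch only: mirrors the example's new defaults (model, task, enforce_eager, max_model_len).
from vllm import LLM

llm = LLM(
    model="intfloat/e5-mistral-7b-instruct",
    task="embed",
    enforce_eager=True,
    max_model_len=1024,
)
outputs = llm.embed(["vLLM V1 now supports embedding models."])
print(len(outputs[0].outputs.embedding))  # pooled embedding dimensionality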

examples/offline_inference/vision_language_embedding.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
     engine_args = EngineArgs(
         model="TIGER-Lab/VLM2Vec-Full",
         task="embed",
+        max_model_len=4096,
         trust_remote_code=True,
         mm_processor_kwargs={"num_crops": 4},
         limit_mm_per_prompt={"image": 1},

tests/compile/test_basic_correctness.py

Lines changed: 18 additions & 14 deletions
@@ -31,7 +31,7 @@ class TestSetting:
     # basic llama model
     TestSetting(
         model="meta-llama/Llama-3.2-1B-Instruct",
-        model_args=[],
+        model_args=["--max-model-len", "2048"],
         pp_size=2,
         tp_size=2,
         attn_backend="FLASHINFER",
@@ -41,7 +41,7 @@ class TestSetting:
     # llama model with quantization
     TestSetting(
         model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-        model_args=["--quantization", "gptq"],
+        model_args=["--quantization", "gptq", "--max-model-len", "2048"],
         pp_size=1,
         tp_size=1,
         attn_backend="FLASH_ATTN",
@@ -51,7 +51,7 @@ class TestSetting:
     # MoE model
     TestSetting(
         model="ibm/PowerMoE-3b",
-        model_args=[],
+        model_args=["--max-model-len", "2048"],
         pp_size=1,
         tp_size=2,
         attn_backend="FLASH_ATTN",
@@ -61,23 +61,27 @@ class TestSetting:
     # embedding model
     TestSetting(
         model="BAAI/bge-multilingual-gemma2",
-        model_args=["--task", "embed", "--dtype", "bfloat16"],
+        model_args=[
+            "--task", "embed", "--dtype", "bfloat16", "--max-model-len",
+            "2048"
+        ],
         pp_size=1,
         tp_size=1,
         attn_backend="FLASH_ATTN",
         method="encode",
         fullgraph=True,
     ),
-    # encoder-based embedding model (BERT)
-    TestSetting(
-        model="BAAI/bge-base-en-v1.5",
-        model_args=["--task", "embed"],
-        pp_size=1,
-        tp_size=1,
-        attn_backend="XFORMERS",
-        method="encode",
-        fullgraph=True,
-    ),
+    # TODO: bert models are not supported in V1 yet
+    # # encoder-based embedding model (BERT)
+    # TestSetting(
+    #     model="BAAI/bge-base-en-v1.5",
+    #     model_args=["--task", "embed"],
+    #     pp_size=1,
+    #     tp_size=1,
+    #     attn_backend="XFORMERS",
+    #     method="encode",
+    #     fullgraph=True,
+    # ),
     # vision language model
     TestSetting(
         model="microsoft/Phi-3.5-vision-instruct",

tests/conftest.py

Lines changed: 3 additions & 0 deletions
@@ -145,13 +145,16 @@ def run_with_both_engines(request, monkeypatch):
     # Automatically runs tests twice, once with V1 and once without
     use_v1 = request.param
     # Tests decorated with `@skip_v1` are only run without v1
+    skip_v0 = request.node.get_closest_marker("skip_v0")
     skip_v1 = request.node.get_closest_marker("skip_v1")
 
     if use_v1:
         if skip_v1:
             pytest.skip("Skipping test on vllm V1")
         monkeypatch.setenv('VLLM_USE_V1', '1')
     else:
+        if skip_v0:
+            pytest.skip("Skipping test on vllm V0")
         monkeypatch.setenv('VLLM_USE_V1', '0')
 
     yield
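
Taken together with the autouse fixtures added in the test files below, the new skip_v0 marker lets a module run every test on both engines while pinning individual tests to one of them. A minimal usage sketch (the test names are hypothetical):

import pytest

@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Runs each test in this module twice: once with VLLM_USE_V1=1, once with VLLM_USE_V1=0.
    pass

@pytest.mark.skip_v0
def test_v1_only_behavior():
    ...  # skipped when the V0 engine is selected

@pytest.mark.skip_v1
def test_v0_only_behavior():
    ...  # skipped when the V1 engine is selected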

tests/entrypoints/llm/test_encode.py

Lines changed: 20 additions & 4 deletions
@@ -8,6 +8,8 @@
 from vllm import LLM, PoolingParams, PoolingRequestOutput
 from vllm.distributed import cleanup_dist_env_and_memory
 
+from ...models.utils import check_embeddings_close
+
 MODEL_NAME = "intfloat/multilingual-e5-small"
 
 PROMPTS = [
@@ -27,6 +29,14 @@
 ]
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.fixture(scope="module")
 def llm():
     # pytest caches the fixture so we use weakref.proxy to
@@ -46,9 +56,15 @@ def llm():
     cleanup_dist_env_and_memory()
 
 
-def assert_outputs_equal(o1: list[PoolingRequestOutput],
+def assert_outputs_match(o1: list[PoolingRequestOutput],
                          o2: list[PoolingRequestOutput]):
-    assert [o.outputs for o in o1] == [o.outputs for o in o2]
+    check_embeddings_close(
+        embeddings_0_lst=[o.outputs.data for o in o1],
+        embeddings_1_lst=[o.outputs.data for o in o2],
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )
 
 
 @pytest.mark.skip_global_cleanup
@@ -63,7 +79,7 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
 
     v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
                            pooling_params=pooling_params)
-    assert_outputs_equal(v1_output, v2_output)
+    assert_outputs_match(v1_output, v2_output)
 
 
 @pytest.mark.skip_global_cleanup
@@ -80,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
         } for p in TOKEN_IDS],
         pooling_params=pooling_params,
     )
-    assert_outputs_equal(v1_output, v2_output)
+    assert_outputs_match(v1_output, v2_output)
 
 
 @pytest.mark.skip_global_cleanup
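
The exact-equality assert is replaced by check_embeddings_close with a 1e-2 tolerance, since the two engines need not produce bit-identical pooled outputs. For intuition, a cosine-similarity check in the same spirit could look like the sketch below; this is an illustration, the real helper lives in tests/models/utils.py and may differ in detail.

import torch
import torch.nn.functional as F

def embeddings_close(emb_a: list[float], emb_b: list[float], tol: float = 1e-2) -> bool:
    # Treat two embeddings as "matching" when their cosine similarity is within tol of 1.
    a = torch.tensor(emb_a, dtype=torch.float32)
    b = torch.tensor(emb_b, dtype=torch.float32)
    sim = F.cosine_similarity(a, b, dim=0)
    return bool(1.0 - sim.item() <= tol)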

tests/entrypoints/openai/test_embedding.py

Lines changed: 8 additions & 0 deletions
@@ -21,6 +21,14 @@
 DTYPE = "bfloat16"
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.fixture(scope="module")
 def server():
     args = [

tests/entrypoints/openai/test_pooling.py

Lines changed: 11 additions & 4 deletions
@@ -7,6 +7,7 @@
 import pytest
 import requests
 
+from tests.models.utils import check_embeddings_close
 from vllm.entrypoints.openai.protocol import PoolingResponse
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
@@ -223,8 +224,11 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
             np.frombuffer(base64.b64decode(data.data),
                           dtype="float32").tolist())
 
-    assert responses_float.data[0].data == decoded_responses_base64_data[0]
-    assert responses_float.data[1].data == decoded_responses_base64_data[1]
+    check_embeddings_close(
+        embeddings_0_lst=[d.data for d in responses_float.data],
+        embeddings_1_lst=decoded_responses_base64_data,
+        name_0="float32",
+        name_1="base64")
 
     # Default response is float32 decoded from base64 by OpenAI Client
     default_response = requests.post(
@@ -237,5 +241,8 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
     default_response.raise_for_status()
     responses_default = PoolingResponse.model_validate(default_response.json())
 
-    assert responses_float.data[0].data == responses_default.data[0].data
-    assert responses_float.data[1].data == responses_default.data[1].data
+    check_embeddings_close(
+        embeddings_0_lst=[d.data for d in responses_default.data],
+        embeddings_1_lst=[d.data for d in responses_default.data],
+        name_0="float32",
+        name_1="base64")
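
For reference, decoding base64-encoded pooling output on the client side follows the same np.frombuffer pattern used in this test. A hedged sketch against a locally running OpenAI-compatible server; the endpoint path, request fields, and model name are assumptions drawn from this test file, not guaranteed API:

import base64

import numpy as np
import requests

resp = requests.post(
    "http://localhost:8000/pooling",  # assumed local server URL
    json={
        "model": "intfloat/multilingual-e5-small",  # placeholder model name
        "input": ["hello world"],
        "encoding_format": "base64",
    },
)
resp.raise_for_status()
for item in resp.json()["data"]:
    vec = np.frombuffer(base64.b64decode(item["data"]), dtype="float32")
    print(vec.shape)  # one float32 vector per input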

tests/entrypoints/openai/test_rerank.py

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,14 @@
 DTYPE = "bfloat16"
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.fixture(scope="module")
 def server():
     args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]

tests/entrypoints/openai/test_score.py

Lines changed: 9 additions & 0 deletions
@@ -11,6 +11,15 @@
 
 from ...utils import RemoteOpenAIServer
 
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 MODELS = [
     {
         "name": "BAAI/bge-reranker-v2-m3",

tests/models/language/pooling/test_classification.py

Lines changed: 9 additions & 1 deletion
@@ -6,6 +6,14 @@
 
 from vllm.platforms import current_platform
 
+# TODO: enable when float32 is supported by V1
+# @pytest.fixture(autouse=True)
+# def v1(run_with_both_engines):
+#     # Simple autouse wrapper to run both engines for each test
+#     # This can be promoted up to conftest.py to run for every
+#     # test in a package
+#     pass
+
 
 @pytest.mark.parametrize(
     "model",
@@ -29,7 +37,7 @@ def test_models(
         # switch to use ROCm CK FA backend
         monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.classify(example_prompts)
 
     with hf_runner(model,
