Support embedding models in V1 #16188

Merged · Jun 19, 2025 · 98 commits

Commits (the diff below reflects changes from 47 of the 98 commits)
f36c4f9
Remove guardrails that prevent V1 from trying to run embedding models
maxdebayser Mar 24, 2025
acf4638
hack v1 flash_attn to support encoder_only
maxdebayser Apr 3, 2025
b13bbc0
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 3, 2025
8debea0
Revert changes to disable kv caching for encoder-only models
maxdebayser Apr 3, 2025
8d97b9c
Add pooling support in v1
maxdebayser Apr 5, 2025
d60b22b
First end-to-end working version of Bert embeddings in V1
maxdebayser Apr 7, 2025
6bebbb8
Support warmup for pooling models in V1
maxdebayser Apr 7, 2025
6dafd71
address review comments
maxdebayser Apr 7, 2025
e2724a2
address review comments
maxdebayser Apr 7, 2025
56ff6cd
remove debug prints
maxdebayser Apr 7, 2025
fc57edd
address review comments
maxdebayser Apr 7, 2025
64a0e62
Fix cross encoder models in V1 and enable tests for pooling models
maxdebayser Apr 8, 2025
4014d41
address review comments
maxdebayser Apr 8, 2025
87a95a8
Merge branch 'main' into v1_embeddings
maxdebayser Apr 8, 2025
902c129
address review comments
maxdebayser Apr 8, 2025
2c68855
re-enable large embedding models
maxdebayser Apr 8, 2025
8afd8f5
address review comments
maxdebayser Apr 8, 2025
7762976
Merge branch 'main' into v1_embeddings
maxdebayser Apr 8, 2025
d7537ae
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 8, 2025
a9e7747
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 9, 2025
17520bd
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 14, 2025
90c611a
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 15, 2025
dec2441
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 17, 2025
a5e83f4
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 23, 2025
187f69b
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 24, 2025
69a0332
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 29, 2025
a9f1721
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 29, 2025
4b066a3
fix merge problems
maxdebayser Apr 30, 2025
43a26dc
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 30, 2025
ca34513
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Apr 30, 2025
bf3033d
Fix missing qwen embedding model param
maxdebayser Apr 30, 2025
67bf727
Make pooling params reach the pooling in V1
maxdebayser May 1, 2025
93b6361
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 1, 2025
d916b88
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 10, 2025
bad4211
fix merge problems
maxdebayser May 10, 2025
35d9bd9
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 11, 2025
dcc6100
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 12, 2025
a4f85b5
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 13, 2025
a5f328a
Merge branch 'upstream_main' into v1_embeddings
maxdebayser May 15, 2025
7c5be88
fix merge problem
maxdebayser May 15, 2025
29b75c9
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 4, 2025
6aa204c
backport changes from the other PR
maxdebayser Jun 4, 2025
e81470c
fix merge errors
maxdebayser Jun 4, 2025
20e7140
address review comments
maxdebayser Jun 4, 2025
6bc1e3d
address review comments
maxdebayser Jun 4, 2025
22825bd
simplify PR
maxdebayser Jun 4, 2025
c889b2e
fix mistake
maxdebayser Jun 4, 2025
24462e4
workaround qwen model test issue
maxdebayser Jun 6, 2025
b5f21f2
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 6, 2025
79d1b95
revert unecessary change
maxdebayser Jun 6, 2025
b3a0491
remove duplicated code
maxdebayser Jun 6, 2025
b4ab556
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 6, 2025
1a82e56
remove encoder model support to simplify PR
maxdebayser Jun 7, 2025
a66801b
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 9, 2025
660dd9c
fix several tests
maxdebayser Jun 9, 2025
808c996
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 9, 2025
cdd70c9
Fix test
maxdebayser Jun 9, 2025
0832115
disable bert test
maxdebayser Jun 9, 2025
10bbf74
fix tests
maxdebayser Jun 9, 2025
ee892aa
limit context length to fit test GPU
maxdebayser Jun 9, 2025
2e12eba
limit context length to fit test GPU
maxdebayser Jun 9, 2025
14fcf24
fix test
maxdebayser Jun 10, 2025
0624435
fix test
maxdebayser Jun 10, 2025
706fdb2
Merge branch 'main' into v1_embeddings
22quinn Jun 10, 2025
051f6d4
Fix _construct_cached_request_state
22quinn Jun 10, 2025
214cf06
Fix v1 tests
22quinn Jun 10, 2025
8193bd0
Merge pull request #1 from 22quinn/v1_embeddings
maxdebayser Jun 10, 2025
65b8377
fix test
maxdebayser Jun 10, 2025
33d7f74
Merge branch 'v1_embeddings' of github.com:maxdebayser/vllm into v1_e…
maxdebayser Jun 10, 2025
4ee822a
reduce max_model_len to fit in test gpu
maxdebayser Jun 10, 2025
7242731
fix test
maxdebayser Jun 10, 2025
a4f460b
fix test
maxdebayser Jun 10, 2025
35ca640
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 12, 2025
17f6177
fix test
maxdebayser Jun 12, 2025
3f0d42e
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 12, 2025
74d73cc
use torch.split
maxdebayser Jun 12, 2025
e6a66dc
enable cuda graphs
maxdebayser Jun 12, 2025
4cca774
fix unecessary config.py changes
maxdebayser Jun 12, 2025
8ef1982
fix error message
maxdebayser Jun 12, 2025
28d00d1
remove unused import
maxdebayser Jun 12, 2025
e634f60
fix docstring
maxdebayser Jun 12, 2025
053475c
revert unnecessary code changes
maxdebayser Jun 12, 2025
6228f64
remove debug prints
maxdebayser Jun 12, 2025
42c802a
fix refactoring bug
maxdebayser Jun 12, 2025
f771a19
fix refactoring bug
maxdebayser Jun 12, 2025
02c47ad
Fix default chunked prefill for pooling models
maxdebayser Jun 13, 2025
1fd252c
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 13, 2025
c5c0d97
Revert handling of case that can never happen
maxdebayser Jun 13, 2025
acfc9cc
fix small bug
maxdebayser Jun 13, 2025
225b808
fix small bugs
maxdebayser Jun 13, 2025
2b86c13
fix silly mistake
maxdebayser Jun 13, 2025
2983252
reduce memory usage for small ci gpus
maxdebayser Jun 13, 2025
58c556d
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 13, 2025
878d56a
enable chunked prefill by default for models that support it
maxdebayser Jun 14, 2025
2db273f
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 14, 2025
114af27
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 16, 2025
bc0219d
address review comments
maxdebayser Jun 16, 2025
221f013
Merge branch 'upstream_main' into v1_embeddings
maxdebayser Jun 19, 2025
2 changes: 1 addition & 1 deletion benchmarks/kernels/bench_fp8_gemm.py
@@ -5,11 +5,11 @@
import itertools

import torch
-import triton
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
+from vllm.triton_utils import triton


@triton.testing.perf_report(
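
The only change in this benchmark is swapping the direct `import triton` for the `vllm.triton_utils` shim, so the module no longer fails at import time on installations without Triton. A minimal sketch of what such a shim can look like (illustrative only; the `HAS_TRITON` name and the placeholder behavior are assumptions, not vLLM's actual `vllm/triton_utils` implementation):

```python
# Sketch of an optional Triton import shim (assumed structure, not vLLM's code).
try:
    import triton  # use the real Triton when it is installed
    HAS_TRITON = True
except ImportError:
    triton = None  # callers are expected to check HAS_TRITON before using it
    HAS_TRITON = False

__all__ = ["triton", "HAS_TRITON"]
```
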
24 changes: 20 additions & 4 deletions tests/entrypoints/llm/test_encode.py
@@ -8,6 +8,8 @@
from vllm import LLM, PoolingParams, PoolingRequestOutput
from vllm.distributed import cleanup_dist_env_and_memory

from ...models.utils import check_embeddings_close

MODEL_NAME = "intfloat/multilingual-e5-small"

PROMPTS = [
@@ -27,6 +29,14 @@
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
@@ -46,9 +56,15 @@ def llm():
cleanup_dist_env_and_memory()


-def assert_outputs_equal(o1: list[PoolingRequestOutput],
+def assert_outputs_match(o1: list[PoolingRequestOutput],
                          o2: list[PoolingRequestOutput]):
-    assert [o.outputs for o in o1] == [o.outputs for o in o2]
+    check_embeddings_close(
+        embeddings_0_lst=[o.outputs.data for o in o1],
+        embeddings_1_lst=[o.outputs.data for o in o2],
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )


@pytest.mark.skip_global_cleanup
@@ -63,7 +79,7 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,

    v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
                           pooling_params=pooling_params)
-    assert_outputs_equal(v1_output, v2_output)
+    assert_outputs_match(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
@@ -80,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
        } for p in TOKEN_IDS],
        pooling_params=pooling_params,
    )
-    assert_outputs_equal(v1_output, v2_output)
+    assert_outputs_match(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
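
The substantive change in this test file is that pooling outputs from the two engines are no longer compared for exact equality: `assert_outputs_match` delegates to `check_embeddings_close` with a `1e-2` tolerance, since V0 and V1 may use different kernels and produce slightly different numerics. A rough sketch of that kind of check (illustrative only; the cosine-similarity criterion and the `embeddings_close` helper are assumptions, not the actual `tests/models/utils.py` implementation):

```python
import torch

def embeddings_close(emb_a, emb_b, tol: float = 1e-2) -> bool:
    # Treat two embeddings as equivalent when their cosine similarity is
    # within `tol` of 1.0, rather than requiring bit-exact equality.
    a = torch.as_tensor(emb_a, dtype=torch.float64)
    b = torch.as_tensor(emb_b, dtype=torch.float64)
    sim = torch.nn.functional.cosine_similarity(a, b, dim=0)
    return bool(sim >= 1.0 - tol)

# e.g. assert embeddings_close(v0_out.outputs.data, v1_out.outputs.data)
```
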
8 changes: 8 additions & 0 deletions tests/entrypoints/openai/test_embedding.py
@@ -21,6 +21,14 @@
DTYPE = "bfloat16"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.fixture(scope="module")
def server():
args = [
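
This autouse `v1(run_with_both_engines)` fixture is repeated across the entrypoint and pooling test modules so that each test runs once per engine, and the comment notes it could be promoted to a package-level `conftest.py`. A sketch of what that promotion could look like (the `VLLM_USE_V1` environment toggle and the fixture body are assumptions; only the fixture names come from this diff):

```python
# Hypothetical package-level conftest.py (sketch, not the actual vLLM conftest).
import pytest

@pytest.fixture(params=[True, False], ids=["v1", "v0"])
def run_with_both_engines(request, monkeypatch):
    # Assumed mechanism: select the engine via the VLLM_USE_V1 env var.
    monkeypatch.setenv("VLLM_USE_V1", "1" if request.param else "0")
    yield

@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Autouse wrapper: every test in the package now runs under both engines.
    pass
```
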
8 changes: 8 additions & 0 deletions tests/entrypoints/openai/test_rerank.py
@@ -12,6 +12,14 @@
DTYPE = "bfloat16"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.fixture(scope="module")
def server():
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
9 changes: 9 additions & 0 deletions tests/entrypoints/openai/test_score.py
@@ -11,6 +11,15 @@

from ...utils import RemoteOpenAIServer


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


MODELS = [
{
"name": "BAAI/bge-reranker-v2-m3",
8 changes: 8 additions & 0 deletions tests/models/language/pooling/test_classification.py
@@ -6,6 +6,14 @@

from vllm.platforms import current_platform

# TODO: enable when float32 is supported by V1
# @pytest.fixture(autouse=True)
# def v1(run_with_both_engines):
#     # Simple autouse wrapper to run both engines for each test
#     # This can be promoted up to conftest.py to run for every
#     # test in a package
#     pass


@pytest.mark.parametrize(
"model",
8 changes: 8 additions & 0 deletions tests/models/language/pooling/test_embedding.py
@@ -8,6 +8,14 @@
from ...utils import check_embeddings_close


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.mark.parametrize(
"model",
[
8 changes: 8 additions & 0 deletions tests/models/language/pooling/test_jina.py
@@ -36,6 +36,14 @@
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.fixture(scope="module", params=SCORING_MODELS)
def model_name(request):
yield request.param
9 changes: 9 additions & 0 deletions tests/models/language/pooling/test_scoring.py
@@ -23,6 +23,15 @@
"The capital of Germany is Berlin.",
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


DTYPE = "half"


38 changes: 36 additions & 2 deletions vllm/config.py
@@ -706,6 +706,9 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]:
if isinstance(self.override_pooler_config, dict):
self.override_pooler_config = PoolerConfig(
**self.override_pooler_config)
logger.warning("CUDA graph is not supported for pooling yet, "
"fallback to the eager mode.")
self.enforce_eager = True

pooler_config = self.override_pooler_config or PoolerConfig()

@@ -4439,14 +4442,45 @@ def __post_init__(self):
"Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION

disable_cascade_reasons: list[str] = []

if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn:
logger.warning_once(
disable_cascade_reasons.append(
"full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
self.cache_config.enable_prefix_caching = False

disable_chunked_prefill_reasons: list[str] = []

if self.model_config and self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")

disable_cascade_reasons.append(
"Loaded model for pooling; disabling cascade attention.")

if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons:
logger.info(reason)
self.scheduler_config.enable_chunked_prefill = False
self.scheduler_config.chunked_prefill_enabled = False
self.scheduler_config.long_prefill_token_threshold = 0
self.scheduler_config.max_num_batched_tokens = max(
self.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)

if self.cache_config is not None:
self.cache_config.enable_prefix_caching = False

if disable_cascade_reasons:
for reason in disable_cascade_reasons:
logger.info(reason)
self.model_config.disable_cascade_attn = True

if (self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events
and not self.cache_config.enable_prefix_caching):
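
Net effect of the `config.py` changes: loading a pooling model under V1 automatically falls back to eager execution, turns off cascade attention, and disables chunked prefill and prefix caching unless the pooler's `pooling_type` is "last", logging the reason for each fallback. A minimal usage sketch (assuming the V1 engine is selected, e.g. via `VLLM_USE_V1=1`; the `task="embed"` value is taken from the `TaskOption` literal quoted in the review thread below, and the model is the one used in `tests/entrypoints/llm/test_encode.py`):

```python
from vllm import LLM, PoolingParams

# No extra flags needed: the engine applies the pooling-specific fallbacks
# (eager mode, no cascade attention, chunked prefill/prefix caching only
# for "last" pooling) during config post-init.
llm = LLM(model="intfloat/multilingual-e5-small", task="embed")
outputs = llm.encode(["The capital of France is Paris."],
                     pooling_params=PoolingParams())
print(outputs[0].outputs.data)  # pooled embedding for the prompt
```
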
6 changes: 0 additions & 6 deletions vllm/engine/arg_utils.py
@@ -1336,12 +1336,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

-        # No Embedding Models so far.
-        if model_config.task not in ["generate"]:
-            _raise_or_fallback(feature_name=f"--task {model_config.task}",
-                               recommend_to_remove=False)
-            return False
Review thread on lines -1352 to -1356:

Reviewer (Member):
Did you mean to enable all tasks here?

TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                     "score", "reward", "transcription"]

Author (Contributor):
I'll double-check the "transcription" task, but the others, yes. Is this causing a problem?

Reviewer (Member):
Nope, it just caused a conflict in my branch where I had enabled "transcription", and I thought maybe it had enabled more than you intended. It's fine if you meant to!

I'm not sure about transcription either. I know it wouldn't work with Whisper, but that'll still get blocked because the model is marked as V0-only. Since all models should have the V0-only marker where needed, this check probably isn't necessary.

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
2 changes: 1 addition & 1 deletion vllm/entrypoints/llm.py
@@ -1230,7 +1230,7 @@ def score(
        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
-        tokenizer = self.llm_engine.get_tokenizer()
+        tokenizer = self.get_tokenizer()

def ensure_str(prompt: SingletonPrompt):
if isinstance(prompt, dict):
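
For context on the code path touched here: `LLM.score` decodes token-id prompts back to text because cross-encoder tokenizers can't accept token lists for the `text`/`text_pair` kwargs, and it now obtains that tokenizer via the `LLM`-level `self.get_tokenizer()` helper instead of reaching into `self.llm_engine`. A usage sketch of the affected entry point (hedged: the positional query/documents call style and the `.outputs.score` attribute are assumptions based on common usage, not taken from this diff; the model name comes from `tests/entrypoints/openai/test_score.py` above):

```python
from vllm import LLM

llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
outputs = llm.score(
    "What is the capital of France?",
    ["The capital of France is Paris.",
     "The capital of Germany is Berlin."],
)
for out in outputs:
    print(out.outputs.score)  # one relevance score per query/document pair
```
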