Skip to content

Add support for encoder embedding models #19988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
5dee54d
Add support for encoder embedding models
maxdebayser Jun 23, 2025
7eb9d28
Fix CUDA graphs for BERT models
maxdebayser Jul 1, 2025
67691e0
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 1, 2025
d3099a9
Fix cuda graph initialization of token type ids
maxdebayser Jul 1, 2025
613ff3b
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 2, 2025
20c41e4
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 2, 2025
ba86026
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 8, 2025
b4f5ead
Fix missing args
maxdebayser Jul 9, 2025
c4060d1
relax assertion
maxdebayser Jul 9, 2025
01d2a65
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 9, 2025
80930d8
fix missing arg
maxdebayser Jul 9, 2025
d881f0a
fix missing arg
maxdebayser Jul 10, 2025
90a25d0
remove model from unsupported list
maxdebayser Jul 10, 2025
6686550
fix missing arg
maxdebayser Jul 10, 2025
cc76777
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 10, 2025
136c9b3
fix tests
maxdebayser Jul 10, 2025
b232491
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 14, 2025
cf5e6b8
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 16, 2025
e19c738
fix tests
maxdebayser Jul 16, 2025
e255f30
fix tests
maxdebayser Jul 16, 2025
ee5950c
add missing arg
maxdebayser Jul 16, 2025
78a2e57
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 16, 2025
a5cfc84
add missing arg
maxdebayser Jul 16, 2025
63fd783
Merge branch 'upstream_main' into v1_embeddings_full
maxdebayser Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions tests/models/language/pooling/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,13 @@ def v1(run_with_both_engines):
marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
# [Encoder-only]
pytest.param("BAAI/bge-base-en-v1.5",
marks=[
pytest.mark.core_model, pytest.mark.cpu_model,
pytest.mark.skip_v1
]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("intfloat/multilingual-e5-small",
marks=[pytest.mark.skip_v1]),
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
marks=[pytest.mark.skip_v0]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider removing the skip mark for v0 tests, as it seems the models are now supported in both engines.

# [Cross-Encoder]
pytest.param("sentence-transformers/stsb-roberta-base-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
],
)
def test_models(
Expand Down
8 changes: 8 additions & 0 deletions tests/models/language/pooling/test_jina.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
Expand Down
9 changes: 9 additions & 0 deletions tests/models/language/pooling/test_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
"The capital of Germany is Berlin.",
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


DTYPE = "half"


Expand Down
2 changes: 1 addition & 1 deletion tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,4 +916,4 @@ def test_get_kv_cache_config():
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
])
])
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1682,7 +1682,8 @@ def _set_default_args_v1(self, usage_context: UsageContext,

if (self.max_num_seqs is None
and usage_context in default_max_num_seqs):
self.max_num_seqs = default_max_num_seqs[usage_context]
self.max_num_seqs = min(default_max_num_seqs[usage_context],
self.max_num_batched_tokens or sys.maxsize)

logger.debug("Setting max_num_seqs to %d for %s usage context.",
self.max_num_seqs, use_context_value)
Expand Down
24 changes: 10 additions & 14 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
Expand All @@ -28,7 +27,7 @@
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)

from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .utils import WeightsMapper, maybe_prefix


Expand Down Expand Up @@ -57,7 +56,6 @@ def __init__(self, config: BertConfig):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
Expand Down Expand Up @@ -97,7 +95,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return pooled_output


@support_torch_compile
class BertEncoder(nn.Module):

def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
Expand Down Expand Up @@ -315,6 +312,7 @@ def forward(self, hidden_states: torch.Tensor,
return hidden_states


@support_torch_compile
class BertModel(nn.Module, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}

Expand Down Expand Up @@ -342,13 +340,9 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
attn_metadata = get_forward_context().attn_metadata
assert hasattr(attn_metadata, "seq_lens_tensor")
hidden_states = self.embeddings(
input_ids=input_ids,
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
return self.encoder(hidden_states)

def load_weights(self, weights: Iterable[tuple[str,
Expand Down Expand Up @@ -388,7 +382,7 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params


class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
Expand All @@ -411,11 +405,13 @@ def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

Expand Down Expand Up @@ -446,8 +442,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
softmax=False)


class BertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding, SupportsQuant):
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
Expand Down
97 changes: 67 additions & 30 deletions vllm/model_executor/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from transformers import RobertaConfig

from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
Expand All @@ -22,7 +23,7 @@
get_cross_encoder_activation_function)

from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .interfaces import SupportsCrossEncoding


class RobertaEmbedding(nn.Module):
Expand Down Expand Up @@ -52,39 +53,13 @@ def __init__(self, config: RobertaConfig):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:

input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
pos_list = []
token_list = []
offset = 0
for seq_len in seq_lens:
pos_list.append(position_ids[offset:offset + seq_len])
token_list.append(input_ids[offset:offset + seq_len])
offset += seq_len

new_pos_list = []
for positions, tokens in zip(pos_list, token_list):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
new_pos_list.append(
create_position_ids_from_input_ids(tokens, self.padding_idx))
position_ids = torch.cat(new_pos_list)

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
Expand Down Expand Up @@ -126,6 +101,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
_pooler: An instance of Pooler used for pooling operations.
"""

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:

# Fix Roberta positions here outside of the CUDA graph.
# Because we need the to extract the sequences from
# input_ids the control flow is data dependent.
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)

return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> Union[BertModel, BertWithRope]:
Expand All @@ -150,8 +151,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
assert len(loaded), "Unable to load RobertaEmbeddingModel"


class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsV0Only):
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
Expand All @@ -177,6 +177,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id

self.default_activation_function = \
get_cross_encoder_activation_function(config)
Expand Down Expand Up @@ -221,6 +222,9 @@ def forward(
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)
return self.roberta(input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,
Expand Down Expand Up @@ -252,6 +256,39 @@ def create_position_ids_from_input_ids(input_ids,
return incremental_indices.long() + padding_idx


def replace_roberta_positions(input_ids: torch.Tensor,
position_ids: torch.Tensor,
padding_idx: int) -> None:

seq_lens: Optional[torch.Tensor] = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None: # can be None during warmup
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
# TODO: remove "seq_lens_tensor" after V0 is removed
seq_lens = getattr(attn_metadata, "seq_lens_tensor",
getattr(attn_metadata, "seq_lens", None))

if seq_lens is not None:
assert isinstance(seq_lens, torch.Tensor)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
token_list = torch.split(input_ids[:torch.sum(seq_lens)],
seq_lens.tolist())

offset = 0
for tokens in token_list:
length = tokens.shape[0]
position_ids[offset:offset+length] = \
create_position_ids_from_input_ids(tokens, padding_idx)
offset = offset + length


def roberta_task_weights_filter(
all_weights: Iterable[tuple[str, torch.Tensor]]
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str,
Expand Down
30 changes: 25 additions & 5 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,13 @@ def __init__(
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/decoder cross-attention "
"is not implemented for "
"FlashAttentionImpl")
Comment on lines +436 to 441
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The NotImplementedError message is not very descriptive. It should indicate what types of attention are not implemented, and what attention types are supported.

        if attn_type not in [
                AttentionType.DECODER, AttentionType.ENCODER_ONLY
        ]:
            raise NotImplementedError("FlashAttentionImpl only supports DECODER and ENCODER_ONLY attention types.")

self.attn_type = attn_type
self.use_irope = use_irope
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
Expand Down Expand Up @@ -546,7 +548,7 @@ def forward(
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
causal=_get_causal_option(self.attn_type),
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
Expand Down Expand Up @@ -749,3 +751,21 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse)


def _get_causal_option(attn_type: str) -> bool:
"""
Determine whether the given attention type is suitable for causal
attention mechanisms.

Args:
attn_type (AttentionType): The type of attention being evaluated

Returns:
bool: Returns `True` if the attention type is suitable for causal
attention (i.e., not encoder, encoder-only, or encoder-decoder),
otherwise returns `False`.
"""
return not (attn_type == AttentionType.ENCODER
or attn_type == AttentionType.ENCODER_ONLY
or attn_type == AttentionType.ENCODER_DECODER)
1 change: 1 addition & 0 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
dtype=spec.dtype,
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
attn_type=str(spec.attn_type),
)

if is_hybrid(kv_cache_spec):
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class NewRequestData:

req_id: str
prompt_token_ids: list[int]
token_type_ids: Optional[list[int]]
mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
Expand All @@ -42,6 +43,7 @@ def from_request(
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
token_type_ids=request.token_type_ids,
mm_inputs=request.mm_inputs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
Expand Down
Loading