
Commit 8a2c588 (1 parent: 7ba34b1)

maxdebayser and russellb committed

Support encoder-only models without KV-Cache

Add support for encoder models such as BERT, which don't support a KV cache due to their non-causal attention. Since the KV cache spec is used to build the attention metadata for decoder models, this PR initializes the attention metadata builders for encoder-only models directly from the layers and adds a function to build the attention metadata. This PR combines elements of PRs vllm-project#21088 and vllm-project#19988.

Summary of changes:

**Flash Attention Backend:**
- Implement encoder self-attention support without using a KV cache

**Scheduler:**
- Disable chunked prefill for models without a KV cache

**GPU Model Runner:**
- Implement encoder-only attention metadata building for self-attention

Related to:
- V0 deprecation: vllm-project#18571
- 2025 Q3 roadmap: vllm-project#20336

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
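For readers unfamiliar with the attention-metadata plumbing, here is a minimal sketch of what "building attention metadata directly from sequence lengths" can look like for encoder-only, non-causal attention without a KV cache. The `EncoderOnlyAttentionMetadata` and `build_encoder_attn_metadata` names below are illustrative, not the actual vLLM classes; the point is that only cumulative sequence lengths and the maximum sequence length are needed, since there are no block tables or slot mappings to compute.

```python
# Illustrative sketch only: names and structure are hypothetical,
# not the actual vLLM v1 attention-metadata classes.
from dataclasses import dataclass

import torch


@dataclass
class EncoderOnlyAttentionMetadata:
    # Cumulative sequence lengths for variable-length attention,
    # shape [num_seqs + 1], e.g. [0, 5, 12, 20].
    cu_seqlens: torch.Tensor
    max_seqlen: int


def build_encoder_attn_metadata(
        seq_lens: torch.Tensor) -> EncoderOnlyAttentionMetadata:
    """Build metadata for non-causal self-attention without a KV cache.

    Because there is no cache, no block tables or slot mappings are needed;
    the prompt tokens of each sequence simply attend to each other.
    """
    cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    return EncoderOnlyAttentionMetadata(
        cu_seqlens=cu_seqlens,
        max_seqlen=int(seq_lens.max().item()),
    )
```

For example, `seq_lens = torch.tensor([5, 7, 8])` gives `cu_seqlens = [0, 5, 12, 20]` and `max_seqlen = 8`, which is enough to run variable-length, non-causal self-attention over a flattened token batch.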

File tree: 9 files changed, +357 −94 lines

tests/entrypoints/openai/test_rerank.py

Lines changed: 1 addition & 1 deletion
@@ -124,4 +124,4 @@ def test_invocations(server: RemoteOpenAIServer):
                                                  invocation_output["results"]):
         assert rerank_result.keys() == invocations_result.keys()
         assert rerank_result["relevance_score"] == pytest.approx(
-            invocations_result["relevance_score"], rel=0.01)
+            invocations_result["relevance_score"], rel=0.05)

tests/models/language/pooling/test_embedding.py

Lines changed: 3 additions & 11 deletions
@@ -39,17 +39,9 @@ def v1(run_with_both_engines):
     pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
                  marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
     # [Encoder-only]
-    pytest.param(
-        "BAAI/bge-base-en-v1.5",
-        marks=[
-            # CPU only supports V1
-            pytest.mark.core_model,
-            pytest.mark.skip_v1
-        ]),
-    pytest.param("sentence-transformers/all-MiniLM-L12-v2",
-                 marks=[pytest.mark.skip_v1]),
-    pytest.param("intfloat/multilingual-e5-small",
-                 marks=[pytest.mark.skip_v1]),
+    pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
+    pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
+    pytest.param("intfloat/multilingual-e5-small"),
     pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
                  marks=[pytest.mark.skip_v1]),
     # [Cross-Encoder]

tests/models/language/pooling/test_jina.py

Lines changed: 8 additions & 0 deletions
@@ -23,6 +23,14 @@
 ]


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
                            model_info: EmbedModelInfo) -> None:

vllm/engine/arg_utils.py

Lines changed: 2 additions & 1 deletion
@@ -1670,7 +1670,8 @@ def _set_default_args_v1(self, usage_context: UsageContext,

         if (self.max_num_seqs is None
                 and usage_context in default_max_num_seqs):
-            self.max_num_seqs = default_max_num_seqs[usage_context]
+            self.max_num_seqs = min(default_max_num_seqs[usage_context],
+                                    self.max_num_batched_tokens or sys.maxsize)

         logger.debug("Setting max_num_seqs to %d for %s usage context.",
                      self.max_num_seqs, use_context_value)
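A short worked example of the new clamp, with hypothetical values: if the usage-context default for `max_num_seqs` is 256 but the user passes a token budget of 128, the old code would still allow 256 sequences per batch; the new code caps it at 128. The `or sys.maxsize` fallback leaves the default untouched when `max_num_batched_tokens` is unset (`None`).

```python
import sys


def clamp_max_num_seqs(default_max_num_seqs: int,
                       max_num_batched_tokens: int | None) -> int:
    # Mirrors the logic in the diff above: never allow more sequences per
    # batch than there are batched tokens available; None means "no limit".
    return min(default_max_num_seqs, max_num_batched_tokens or sys.maxsize)


assert clamp_max_num_seqs(256, 128) == 128   # user-limited token budget
assert clamp_max_num_seqs(256, None) == 256  # unset -> keep the default
```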

vllm/model_executor/models/bert.py

Lines changed: 7 additions & 11 deletions
@@ -12,7 +12,6 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -59,7 +58,6 @@ def __init__(self, config: BertConfig):
     def forward(
         self,
         input_ids: torch.Tensor,
-        seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -109,7 +107,6 @@ def forward(
         return pooled_output


-@support_torch_compile
 class BertEncoder(nn.Module):

     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
@@ -327,6 +324,7 @@ def forward(self, hidden_states: torch.Tensor,
         return hidden_states


+@support_torch_compile
 class BertModel(nn.Module, SupportsQuant):

     is_pooling_model = True
@@ -357,13 +355,9 @@ def forward(
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:
-            attn_metadata = get_forward_context().attn_metadata
-            assert hasattr(attn_metadata, "seq_lens_tensor")
-            hidden_states = self.embeddings(
-                input_ids=input_ids,
-                seq_lens=attn_metadata.seq_lens_tensor,
-                position_ids=position_ids,
-                token_type_ids=token_type_ids)
+            hidden_states = self.embeddings(input_ids=input_ids,
+                                            position_ids=position_ids,
+                                            token_type_ids=token_type_ids)
         return self.encoder(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
@@ -404,7 +398,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params


-class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
+class BertEmbeddingModel(nn.Module, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
@@ -429,11 +423,13 @@ def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model(input_ids=input_ids,
                           position_ids=positions,
+                          token_type_ids=token_type_ids,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)

vllm/model_executor/models/roberta.py

Lines changed: 64 additions & 27 deletions
@@ -9,6 +9,7 @@
 from transformers import RobertaConfig

 from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -50,39 +51,12 @@ def __init__(self, config: RobertaConfig):
     def forward(
         self,
         input_ids: torch.Tensor,
-        seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)

-        # Replace position ids because in RoBERTa models
-        # they have to start at padding_idx + 1 and ignore
-        # existing padding tokens
-        # References:
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
-        pos_list = []
-        token_list = []
-        offset = 0
-        for seq_len in seq_lens:
-            pos_list.append(position_ids[offset:offset + seq_len])
-            token_list.append(input_ids[offset:offset + seq_len])
-            offset += seq_len
-
-        new_pos_list = []
-        for positions, tokens in zip(pos_list, token_list):
-            # Verify assumption that incoming position are
-            # always a sequence from 0 to N.
-            expected_pos = torch.arange(positions.size()[0],
-                                        dtype=torch.long,
-                                        device=inputs_embeds.device)
-            assert torch.equal(positions, expected_pos)
-            new_pos_list.append(
-                create_position_ids_from_input_ids(tokens, self.padding_idx))
-        position_ids = torch.cat(new_pos_list)
-
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
         if token_type_ids is None:
@@ -124,6 +98,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         _pooler: An instance of Pooler used for pooling operations.
     """

+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # Fix RoBERTa positions here, outside of the CUDA graph.
+        # Because we need to extract the sequences from input_ids,
+        # the control flow is data dependent.
+        replace_roberta_positions(input_ids=input_ids,
+                                  position_ids=positions,
+                                  padding_idx=self.padding_idx)
+
+        return self.model(input_ids=input_ids,
+                          position_ids=positions,
+                          token_type_ids=token_type_ids,
+                          inputs_embeds=inputs_embeds,
+                          intermediate_tensors=intermediate_tensors)
+
     def _build_model(self,
                      vllm_config: VllmConfig,
                      prefix: str = "") -> Union[BertModel, BertWithRope]:
@@ -180,6 +180,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id

         self.num_labels = config.num_labels
         self.roberta = BertModel(vllm_config=vllm_config,
@@ -206,6 +207,9 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        replace_roberta_positions(input_ids=input_ids,
+                                  position_ids=positions,
+                                  padding_idx=self.padding_idx)
         return self.roberta(input_ids=input_ids,
                             position_ids=positions,
                             inputs_embeds=inputs_embeds,
@@ -235,3 +239,36 @@ def create_position_ids_from_input_ids(input_ids,
                                        past_key_values_length) * mask

     return incremental_indices.long() + padding_idx
+
+
+def replace_roberta_positions(input_ids: torch.Tensor,
+                              position_ids: torch.Tensor,
+                              padding_idx: int) -> None:
+
+    seq_lens: Optional[torch.Tensor] = None
+    attn_metadata = get_forward_context().attn_metadata
+    if attn_metadata is not None:  # can be None during warmup
+        if isinstance(attn_metadata, dict):
+            attn_metadata = next(iter(attn_metadata.values()))
+        # TODO: remove "seq_lens_tensor" after V0 is removed
+        seq_lens = getattr(attn_metadata, "seq_lens_tensor",
+                           getattr(attn_metadata, "seq_lens", None))
+
+    if seq_lens is not None:
+        assert isinstance(seq_lens, torch.Tensor)
+
+        # Replace position ids because in RoBERTa models
+        # they have to start at padding_idx + 1 and ignore
+        # existing padding tokens
+        # References:
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
+        token_list = torch.split(input_ids[:torch.sum(seq_lens)],
+                                 seq_lens.tolist())
+
+        offset = 0
+        for tokens in token_list:
+            length = tokens.shape[0]
+            position_ids[offset:offset + length] = \
+                create_position_ids_from_input_ids(tokens, padding_idx)
+            offset = offset + length
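As a worked example of what `replace_roberta_positions` produces (a sketch; the offset-by-padding_idx behavior follows the HuggingFace implementation referenced in the comments above): with RoBERTa's `pad_token_id` of 1, positions for real tokens start at `padding_idx + 1 = 2`, and any padding tokens would keep `padding_idx`.

```python
import torch


def create_position_ids_from_input_ids(input_ids: torch.Tensor,
                                        padding_idx: int,
                                        past_key_values_length: int = 0):
    # Same idea as the helper in roberta.py, specialized to a 1-D token
    # stream: count non-padding tokens cumulatively, then shift by
    # padding_idx so real positions start at padding_idx + 1 while
    # padding tokens stay at padding_idx.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
                           past_key_values_length) * mask
    return incremental_indices.long() + padding_idx


tokens = torch.tensor([0, 42, 43, 2])  # <s> ... </s>, no padding
print(create_position_ids_from_input_ids(tokens, padding_idx=1))
# tensor([2, 3, 4, 5])  -> positions start at padding_idx + 1
```

Doing this rewrite in the model's top-level `forward`, rather than inside the embedding layer, keeps the compiled graph free of data-dependent control flow, which is why the PR moves it there.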
