Skip to content

Support encoder-only models without KV-Cache #21270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/offline_inference/prithvi_geospatial_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import argparse
import datetime
import os
import re
from typing import Union

import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
Expand Down
14 changes: 3 additions & 11 deletions tests/models/language/pooling/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,9 @@ def v1(run_with_both_engines):
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
# [Encoder-only]
pytest.param(
"BAAI/bge-base-en-v1.5",
marks=[
# CPU only supports V1
pytest.mark.core_model,
pytest.mark.skip_v1
]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("intfloat/multilingual-e5-small",
marks=[pytest.mark.skip_v1]),
pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
# [Cross-Encoder]
Expand Down
8 changes: 8 additions & 0 deletions tests/models/language/pooling/test_jina.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/v1/attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def create_common_attn_metadata(
max_query_len=max_query_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
causal=True,
)


Expand Down
3 changes: 1 addition & 2 deletions tests/v1/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import re

import pytest
import regex as re
import requests
import torch

Expand Down
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1640,7 +1640,8 @@ def _set_default_args_v1(self, usage_context: UsageContext,

if (self.max_num_seqs is None
and usage_context in default_max_num_seqs):
self.max_num_seqs = default_max_num_seqs[usage_context]
self.max_num_seqs = min(default_max_num_seqs[usage_context],
self.max_num_batched_tokens or sys.maxsize)

logger.debug("Setting max_num_seqs to %d for %s usage context.",
self.max_num_seqs, use_context_value)
Expand Down
18 changes: 7 additions & 11 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
Expand Down Expand Up @@ -60,7 +59,6 @@ def __init__(self, config: BertConfig):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
Expand Down Expand Up @@ -119,7 +117,6 @@ def forward(
return pooled_output


@support_torch_compile
class BertEncoder(nn.Module):

def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
Expand Down Expand Up @@ -337,6 +334,7 @@ def forward(self, hidden_states: torch.Tensor,
return hidden_states


@support_torch_compile
class BertModel(nn.Module, SupportsQuant):

is_pooling_model = True
Expand Down Expand Up @@ -368,13 +366,9 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
attn_metadata = get_forward_context().attn_metadata
assert hasattr(attn_metadata, "seq_lens_tensor")
hidden_states = self.embeddings(
input_ids=input_ids,
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
return self.encoder(hidden_states)

def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
Expand Down Expand Up @@ -447,7 +441,7 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params


class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
Expand All @@ -474,11 +468,13 @@ def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

Expand Down
85 changes: 64 additions & 21 deletions vllm/model_executor/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from transformers import RobertaConfig

from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (
Expand Down Expand Up @@ -51,33 +52,12 @@ def __init__(self, config: RobertaConfig):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
seq_lens_list = seq_lens.tolist()
new_pos_list = []
for positions, tokens in zip(position_ids.split(seq_lens_list),
input_ids.split(seq_lens_list)):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
new_pos_list.append(
create_position_ids_from_input_ids(tokens, self.padding_idx))
position_ids = torch.cat(new_pos_list)

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
Expand Down Expand Up @@ -119,6 +99,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
_pooler: An instance of Pooler used for pooling operations.
"""

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:

# Fix Roberta positions here outside of the CUDA graph.
# Because we need the to extract the sequences from
# input_ids the control flow is data dependent.
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)

return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> Union[BertModel, BertWithRope]:
Expand Down Expand Up @@ -175,6 +181,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id

self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config,
Expand Down Expand Up @@ -216,6 +223,9 @@ def forward(
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)
return self.roberta(input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,
Expand Down Expand Up @@ -245,3 +255,36 @@ def create_position_ids_from_input_ids(input_ids,
past_key_values_length) * mask

return incremental_indices.long() + padding_idx


def replace_roberta_positions(input_ids: torch.Tensor,
position_ids: torch.Tensor,
padding_idx: int) -> None:

seq_lens: Optional[torch.Tensor] = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None: # can be None during warmup
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
# TODO: remove "seq_lens_tensor" after V0 is removed
seq_lens = getattr(attn_metadata, "seq_lens_tensor",
getattr(attn_metadata, "seq_lens", None))

if seq_lens is not None:
assert isinstance(seq_lens, torch.Tensor)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
token_list = torch.split(input_ids[:torch.sum(seq_lens)],
seq_lens.tolist())

offset = 0
for tokens in token_list:
length = tokens.shape[0]
position_ids[offset:offset+length] = \
create_position_ids_from_input_ids(tokens, padding_idx)
offset = offset + length
Loading