
[Bugfix] Add use_cross_encoder flag to use correct activation in ClassifierPooler #20527


Merged: 4 commits merged on Jul 6, 2025
vllm/entrypoints/llm.py: 2 changes (1 addition, 1 deletion)
@@ -1204,7 +1204,7 @@ def _cross_encoding_score(

input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]

-        pooling_params = PoolingParams()
+        pooling_params = PoolingParams(use_cross_encoder=True)

tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
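For context, a minimal usage sketch of the path this change fixes (the model name and texts are illustrative; LLM.score routes cross-encoder models through _cross_encoding_score, which now sets the flag):

    from vllm import LLM

    # Illustrative sketch: scoring with a cross-encoder model. The internal
    # _cross_encoding_score call now builds PoolingParams(use_cross_encoder=True),
    # so ClassifierPooler applies the cross-encoder activation to these scores.
    llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")
    outputs = llm.score("What is the capital of France?",
                        ["Paris is the capital of France.",
                         "The Eiffel Tower is in Berlin."])
    for out in outputs:
        print(out.outputs.score)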
vllm/entrypoints/openai/protocol.py: 10 changes (6 additions, 4 deletions)
@@ -1156,8 +1156,9 @@ class ScoreRequest(OpenAIBaseModel):

# --8<-- [end:score-extra-params]

-    def to_pooling_params(self):
-        return PoolingParams(additional_data=self.additional_data)
+    def to_pooling_params(self, *, use_cross_encoder: bool = False):
+        return PoolingParams(use_cross_encoder=use_cross_encoder,
+                             additional_data=self.additional_data)


class RerankRequest(OpenAIBaseModel):
@@ -1182,8 +1183,9 @@ class RerankRequest(OpenAIBaseModel):

# --8<-- [end:rerank-extra-params]

-    def to_pooling_params(self):
-        return PoolingParams(additional_data=self.additional_data)
+    def to_pooling_params(self, *, use_cross_encoder: bool = False):
+        return PoolingParams(use_cross_encoder=use_cross_encoder,
+                             additional_data=self.additional_data)


class RerankDocument(BaseModel):
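Because use_cross_encoder is keyword-only, call sites must spell the flag out. A small self-contained sketch of the pattern (the request class below is a stand-in for ScoreRequest, not the real pydantic model):

    from typing import Any, Optional

    from vllm.pooling_params import PoolingParams

    class ScoreRequestSketch:
        """Illustrative stand-in for ScoreRequest; unrelated fields omitted."""
        additional_data: Optional[Any] = None

        def to_pooling_params(self, *, use_cross_encoder: bool = False):
            # Keyword-only: to_pooling_params(True) raises TypeError,
            # forcing an explicit use_cross_encoder=True at call sites.
            return PoolingParams(use_cross_encoder=use_cross_encoder,
                                 additional_data=self.additional_data)

    params = ScoreRequestSketch().to_pooling_params(use_cross_encoder=True)
    assert params.use_cross_encoder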
vllm/entrypoints/openai/serving_score.py: 10 changes (4 additions, 6 deletions)
@@ -25,9 +25,7 @@
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
-                                               PreTrainedTokenizer,
-                                               PreTrainedTokenizerFast)
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import make_async, merge_async_iterators

logger = init_logger(__name__)
@@ -50,7 +48,7 @@ def __init__(

async def _embedding_score(
self,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        tokenizer: AnyTokenizer,
texts_1: list[str],
texts_2: list[str],
request: Union[RerankRequest, ScoreRequest],
@@ -141,7 +139,7 @@ async def _embedding_score(

async def _cross_encoding_score(
self,
-        tokenizer: Union[AnyTokenizer],
+        tokenizer: AnyTokenizer,
texts_1: list[str],
texts_2: list[str],
request: Union[RerankRequest, ScoreRequest],
Expand Down Expand Up @@ -190,7 +188,7 @@ async def _cross_encoding_score(
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

-        pooling_params = request.to_pooling_params()
+        pooling_params = request.to_pooling_params(use_cross_encoder=True)

for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
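End to end, a score request against the OpenAI-compatible server now reaches the pooler with the flag set. A hedged client-side sketch (host, port, endpoint path, and model name are placeholders for a locally running vllm serve instance; adjust to your deployment):

    import requests

    # Illustrative request to a local vLLM OpenAI-compatible server.
    # On the server side, ScoreRequest.to_pooling_params(use_cross_encoder=True)
    # is invoked before the request is scheduled.
    resp = requests.post(
        "http://localhost:8000/score",
        json={
            "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
            "text_1": "What is the capital of France?",
            "text_2": ["Paris is the capital of France."],
        },
    )
    print(resp.json())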
vllm/model_executor/layers/pooler.py: 42 changes (31 additions, 11 deletions)
@@ -15,6 +15,7 @@
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.transformers_utils.config import (
+    get_classification_activation_function,
get_cross_encoder_activation_function)
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

@@ -388,15 +389,14 @@ def __init__(
self.classifier = classifier
self.pooler = pooler

if config.task == "score":
self.default_activation_function = \
get_cross_encoder_activation_function(config.hf_config)
elif config.task == "classify":
self.default_activation_function = nn.Sigmoid() \
if config.hf_config.num_labels == 1 else nn.Softmax()
else:
raise NotImplementedError(f"task={config.task!r} is not supported"
" with the classification pooler")
self.classification_act_fn = get_classification_activation_function(
config.hf_config)
self.cross_encoder_act_fn = get_cross_encoder_activation_function(
config.hf_config)

def _get_act_fn(self, use_cross_encoder: bool):
return (self.cross_encoder_act_fn
if use_cross_encoder else self.classification_act_fn)

def get_prompt_lens(
self,
@@ -446,8 +446,28 @@ def forward(
# apply classifier once on the full batch if possible
pooled_output = self.classifier(pooled_output)

-        # shape: (batch_size, num_labels)
-        scores = self.default_activation_function(pooled_output)
+        if isinstance(pooling_metadata, V0PoolingMetadata):
+            use_cross_encoder_list = [
+                pooling_param.use_cross_encoder
+                for _, pooling_param in pooling_metadata.seq_groups
+            ]
+        else:
+            use_cross_encoder_list = [
+                pooling_param.use_cross_encoder
+                for pooling_param in pooling_metadata.pooling_params
+            ]
+
+        # shape of scores: (batch_size, num_labels)
+        if all(use_cross_encoder == use_cross_encoder_list[0]
+               for use_cross_encoder in use_cross_encoder_list):
Review comment (Collaborator) on lines +461 to +462:

I think we can simplify the condition here.

Suggested change:
-        if all(use_cross_encoder == use_cross_encoder_list[0]
-               for use_cross_encoder in use_cross_encoder_list):
+        if len(set(use_cross_encoder_list)) == 1:
+            act_fn = self._get_act_fn(use_cross_encoder_list[0])
+            scores = act_fn(pooled_output)
+        else:
+            scores = torch.stack([
+                self._get_act_fn(use_cross_encoder)(vecs)
+                for use_cross_encoder, vecs in zip(use_cross_encoder_list,
+                                                   pooled_output)
+            ])

pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
return PoolerOutput(outputs=pooled_outputs)
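To see the batching behavior in isolation, here is a self-contained sketch of the selection logic (tensor shapes, flags, and the two activations are invented for illustration; it uses the reviewer's simplified condition):

    import torch
    import torch.nn as nn

    # Stand-ins for the pooler's two activation functions (invented for this
    # sketch; the real ones come from the model config).
    classification_act_fn = nn.Softmax(dim=-1)  # default classification path
    cross_encoder_act_fn = nn.Sigmoid()         # cross-encoder scoring path

    def get_act_fn(use_cross_encoder: bool) -> nn.Module:
        return cross_encoder_act_fn if use_cross_encoder else classification_act_fn

    # Fake batch: 3 pooled outputs with 2 labels each, and mixed flags.
    pooled_output = torch.randn(3, 2)
    use_cross_encoder_list = [True, False, True]

    if len(set(use_cross_encoder_list)) == 1:
        # Homogeneous batch: one activation applied to the whole tensor.
        scores = get_act_fn(use_cross_encoder_list[0])(pooled_output)
    else:
        # Mixed batch: apply each request's activation to its own row.
        scores = torch.stack([
            get_act_fn(flag)(vecs)
            for flag, vecs in zip(use_cross_encoder_list, pooled_output)
        ])

    print(scores.shape)  # torch.Size([3, 2])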
vllm/model_executor/models/bert.py: 5 changes (0 additions, 5 deletions)
@@ -25,8 +25,6 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
-from vllm.transformers_utils.config import (
-    get_cross_encoder_activation_function)

from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix
@@ -462,9 +460,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config

-        self.default_activation_function = \
-            get_cross_encoder_activation_function(config)

Review comment (Member, Author): This isn't used by the current module so I removed them.

self.num_labels = config.num_labels
self.bert = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
vllm/model_executor/models/roberta.py: 5 changes (0 additions, 5 deletions)
@@ -18,8 +18,6 @@
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
-from vllm.transformers_utils.config import (
-    get_cross_encoder_activation_function)

from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, SupportsV0Only
@@ -178,9 +176,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config

-        self.default_activation_function = \
-            get_cross_encoder_activation_function(config)

self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
vllm/pooling_params.py: 3 changes (3 additions, 0 deletions)
@@ -24,12 +24,14 @@ class PoolingParams(
"""

dimensions: Optional[int] = None
+    use_cross_encoder: bool = False
additional_data: Optional[Any] = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
return PoolingParams(dimensions=self.dimensions,
+                             use_cross_encoder=self.use_cross_encoder,
additional_data=self.additional_data)

def verify(self, model_config: "ModelConfig") -> None:
@@ -54,6 +56,7 @@ def verify(self, model_config: "ModelConfig") -> None:
def __repr__(self) -> str:
return (f"PoolingParams("
f"dimensions={self.dimensions}, "
f"use_cross_encoder={self.use_cross_encoder}, "
f"additional_metadata={self.additional_data})")

def __post_init__(self) -> None:
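A quick sketch of the new field in use (a minimal example; values are illustrative):

    from vllm.pooling_params import PoolingParams

    params = PoolingParams(use_cross_encoder=True)
    copy = params.clone()          # clone() now carries the flag over
    assert copy.use_cross_encoder is True
    print(params)                  # repr includes use_cross_encoder=True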
vllm/transformers_utils/config.py: 20 changes (11 additions, 9 deletions)
@@ -866,24 +866,26 @@ def try_get_generation_config(
return None


-def get_cross_encoder_activation_function(config: PretrainedConfig):
+def get_classification_activation_function(config: PretrainedConfig):
+    return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax()
+
+
+def get_cross_encoder_activation_function(config: PretrainedConfig):
function_name: Optional[str] = None
if hasattr(config, "sentence_transformers") and "activation_fn" in \
config.sentence_transformers:
if (hasattr(config, "sentence_transformers")
and "activation_fn" in config.sentence_transformers):
function_name = config.sentence_transformers["activation_fn"]

elif (hasattr(config, "sbert_ce_default_activation_function")
and config.sbert_ce_default_activation_function is not None):
function_name = config.sbert_ce_default_activation_function

if function_name is not None:
assert function_name.startswith("torch.nn.modules."), \
"Loading of activation functions is restricted to " \
"torch.nn.modules for security reasons"
assert function_name.startswith("torch.nn.modules."), (
"Loading of activation functions is restricted to "
"torch.nn.modules for security reasons")
return resolve_obj_by_qualname(function_name)()
-    else:
-        return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
+
+    return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()


def try_get_safetensors_metadata(
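For illustration, the qualname restriction can be exercised in isolation (resolve_activation below is a stand-in for vLLM's resolve_obj_by_qualname plus instantiation; the config handling is omitted):

    import importlib

    import torch.nn as nn

    def resolve_activation(function_name: str):
        # Mirrors the guard above: only torch.nn.modules qualnames are allowed,
        # so arbitrary code cannot be loaded from model configs.
        assert function_name.startswith("torch.nn.modules."), (
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons")
        module_path, _, cls_name = function_name.rpartition(".")
        return getattr(importlib.import_module(module_path), cls_name)()

    # e.g. a sentence-transformers config carrying {"activation_fn": "...Sigmoid"}:
    act = resolve_activation("torch.nn.modules.activation.Sigmoid")
    assert isinstance(act, nn.Sigmoid)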