From cff74b0ef213429cf001dba8f18f30ab17516ca5 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 6 Jul 2025 16:39:58 +0000 Subject: [PATCH 1/4] [Core] Add `use_cross_encoder` flag to use correct activation function in `ClassifierPooler` Signed-off-by: DarkLight1337 --- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/pooler.py | 43 ++++++++++++++++++++------- vllm/model_executor/models/bert.py | 5 ---- vllm/model_executor/models/roberta.py | 5 ---- vllm/pooling_params.py | 3 ++ vllm/transformers_utils/config.py | 20 +++++++------ 6 files changed, 47 insertions(+), 31 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 68f126c04283..6357c2a37c8f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1204,7 +1204,7 @@ def _cross_encoding_score( input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] - pooling_params = PoolingParams() + pooling_params = PoolingParams(use_cross_encoder=True) tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.llm_engine.model_config.max_model_len, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 48adcc5fef84..81ea6ecddbaa 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -15,6 +15,7 @@ from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.transformers_utils.config import ( + get_classification_activation_function, get_cross_encoder_activation_function) from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata @@ -388,15 +389,14 @@ def __init__( self.classifier = classifier self.pooler = pooler - if config.task == "score": - self.default_activation_function = \ - get_cross_encoder_activation_function(config.hf_config) - elif config.task == "classify": - self.default_activation_function = nn.Sigmoid() \ - if config.hf_config.num_labels == 1 else nn.Softmax() - else: - 
raise NotImplementedError(f"task={config.task!r} is not supported" - " with the classification pooler") + self.classification_act_fn = get_classification_activation_function( + config.hf_config) + self.cross_encoder_act_fn = get_cross_encoder_activation_function( + config.hf_config) + + def _get_act_fn(self, use_cross_encoder: bool): + return (self.cross_encoder_act_fn + if use_cross_encoder else self.classification_act_fn) def get_prompt_lens( self, @@ -446,8 +446,29 @@ def forward( # apply classifier once on the full batch if possible pooled_output = self.classifier(pooled_output) - # shape: (batch_size, num_labels) - scores = self.default_activation_function(pooled_output) + if isinstance(pooling_metadata, V0PoolingMetadata): + use_cross_encoder_list = [ + pooling_param.use_cross_encoder + for _, pooling_param in pooling_metadata.seq_groups + ] + else: + assert isinstance(pooled_data, list) + use_cross_encoder_list = [ + pooling_param.use_cross_encoder + for pooling_param in pooling_metadata.pooling_params + ] + + # shape of scores: (batch_size, num_labels) + if all(use_cross_encoder == use_cross_encoder_list[0] + for use_cross_encoder in use_cross_encoder_list): + act_fn = self._get_act_fn(use_cross_encoder_list[0]) + scores = act_fn(pooled_output) + else: + scores = torch.stack([ + self._get_act_fn(use_cross_encoder)(vecs) + for use_cross_encoder, vecs in zip(use_cross_encoder_list, + pooled_data) + ]) pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index d6f6d9d1fb59..6e955e1c5121 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -25,8 +25,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput -from 
vllm.transformers_utils.config import ( - get_cross_encoder_activation_function) from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only from .utils import WeightsMapper, maybe_prefix @@ -462,9 +460,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - self.default_activation_function = \ - get_cross_encoder_activation_function(config) - self.num_labels = config.num_labels self.bert = BertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "bert"), diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 8fa8b89798d0..048fa827fb2b 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -18,8 +18,6 @@ from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput -from vllm.transformers_utils.config import ( - get_cross_encoder_activation_function) from .bert_with_rope import BertWithRope, JinaRobertaModel from .interfaces import SupportsCrossEncoding, SupportsV0Only @@ -178,9 +176,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - self.default_activation_function = \ - get_cross_encoder_activation_function(config) - self.num_labels = config.num_labels self.roberta = BertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "bert"), diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index b5c327bdd256..106f3e8b22b7 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -24,12 +24,14 @@ class PoolingParams( """ dimensions: Optional[int] = None + use_cross_encoder: bool = False additional_data: Optional[Any] = None output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY def clone(self) -> "PoolingParams": """Returns a deep copy of 
the PoolingParams instance.""" return PoolingParams(dimensions=self.dimensions, + use_cross_encoder=self.use_cross_encoder, additional_data=self.additional_data) def verify(self, model_config: "ModelConfig") -> None: @@ -54,6 +56,7 @@ def verify(self, model_config: "ModelConfig") -> None: def __repr__(self) -> str: return (f"PoolingParams(" f"dimensions={self.dimensions}, " + f"use_cross_encoder={self.use_cross_encoder}, " f"additional_metadata={self.additional_data})") def __post_init__(self) -> None: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5c422a9e3fce..9e9dd6425770 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -866,24 +866,26 @@ def try_get_generation_config( return None -def get_cross_encoder_activation_function(config: PretrainedConfig): +def get_classification_activation_function(config: PretrainedConfig): + return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() + +def get_cross_encoder_activation_function(config: PretrainedConfig): function_name: Optional[str] = None - if hasattr(config, "sentence_transformers") and "activation_fn" in \ - config.sentence_transformers: + if (hasattr(config, "sentence_transformers") + and "activation_fn" in config.sentence_transformers): function_name = config.sentence_transformers["activation_fn"] - elif (hasattr(config, "sbert_ce_default_activation_function") and config.sbert_ce_default_activation_function is not None): function_name = config.sbert_ce_default_activation_function if function_name is not None: - assert function_name.startswith("torch.nn.modules."), \ - "Loading of activation functions is restricted to " \ - "torch.nn.modules for security reasons" + assert function_name.startswith("torch.nn.modules."), ( + "Loading of activation functions is restricted to " + "torch.nn.modules for security reasons") return resolve_obj_by_qualname(function_name)() - else: - return nn.Sigmoid() if config.num_labels == 1 else 
nn.Identity() + + return get_classification_activation_function(config) def try_get_safetensors_metadata( From caa3e875edbc68de72e719f82883dcc7dab856d1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 6 Jul 2025 16:49:14 +0000 Subject: [PATCH 2/4] Update online serving Signed-off-by: DarkLight1337 --- vllm/entrypoints/openai/protocol.py | 10 ++++++---- vllm/entrypoints/openai/serving_score.py | 10 ++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 93d9c588d8d2..d4db238f456e 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1156,8 +1156,9 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] - def to_pooling_params(self): - return PoolingParams(additional_data=self.additional_data) + def to_pooling_params(self, *, use_cross_encoder: bool = False): + return PoolingParams(use_cross_encoder=use_cross_encoder, + additional_data=self.additional_data) class RerankRequest(OpenAIBaseModel): @@ -1182,8 +1183,9 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] - def to_pooling_params(self): - return PoolingParams(additional_data=self.additional_data) + def to_pooling_params(self, *, use_cross_encoder: bool = False): + return PoolingParams(use_cross_encoder=use_cross_encoder, + additional_data=self.additional_data) class RerankDocument(BaseModel): diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 9f333c02ab52..328d4ff0e6c0 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -25,9 +25,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, - PreTrainedTokenizer, - 
PreTrainedTokenizerFast) +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import make_async, merge_async_iterators logger = init_logger(__name__) @@ -50,7 +48,7 @@ def __init__( async def _embedding_score( self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + tokenizer: AnyTokenizer, texts_1: list[str], texts_2: list[str], request: Union[RerankRequest, ScoreRequest], @@ -141,7 +139,7 @@ async def _embedding_score( async def _cross_encoding_score( self, - tokenizer: Union[AnyTokenizer], + tokenizer: AnyTokenizer, texts_1: list[str], texts_2: list[str], request: Union[RerankRequest, ScoreRequest], @@ -190,7 +188,7 @@ async def _cross_encoding_score( # Schedule the request and get the result generator. generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - pooling_params = request.to_pooling_params() + pooling_params = request.to_pooling_params(use_cross_encoder=True) for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" From b863f058c59ff7b689b6a58a31f8864726ae9436 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 6 Jul 2025 16:50:53 +0000 Subject: [PATCH 3/4] Fix activation function Signed-off-by: DarkLight1337 --- vllm/transformers_utils/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 9e9dd6425770..9ccde292974c 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -867,7 +867,7 @@ def try_get_generation_config( def get_classification_activation_function(config: PretrainedConfig): - return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() + return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax() def get_cross_encoder_activation_function(config: PretrainedConfig): @@ -885,7 +885,7 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): "torch.nn.modules for security reasons") 
return resolve_obj_by_qualname(function_name)() - return get_classification_activation_function(config) + return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() def try_get_safetensors_metadata( From 1567dc3d3fdf927e0a7c96bc95976fb5d93a5048 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 6 Jul 2025 16:52:25 +0000 Subject: [PATCH 4/4] Fix per-request activation path to zip over `pooled_output` instead of `pooled_data` Signed-off-by: DarkLight1337 --- vllm/model_executor/layers/pooler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 81ea6ecddbaa..d864a915a073 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -452,7 +452,6 @@ def forward( for _, pooling_param in pooling_metadata.seq_groups ] else: - assert isinstance(pooled_data, list) use_cross_encoder_list = [ pooling_param.use_cross_encoder for pooling_param in pooling_metadata.pooling_params @@ -467,7 +466,7 @@ def forward( scores = torch.stack([ self._get_act_fn(use_cross_encoder)(vecs) for use_cross_encoder, vecs in zip(use_cross_encoder_list, - pooled_data) + pooled_output) ]) pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]