
Commit c18b3b8

[Bugfix] Add use_cross_encoder flag to use correct activation in ClassifierPooler (vllm-project#20527)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 9528e3a commit c18b3b8

File tree

8 files changed: +56, -41 lines


vllm/entrypoints/llm.py

Lines changed: 1 addition & 1 deletion
@@ -1204,7 +1204,7 @@ def _cross_encoding_score(
 
         input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
 
-        pooling_params = PoolingParams()
+        pooling_params = PoolingParams(use_cross_encoder=True)
 
         tokenization_kwargs: dict[str, Any] = {}
         _validate_truncation_size(self.llm_engine.model_config.max_model_len,
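
For context, a hedged sketch of the user-facing path this hunk sits on (the model name and printed fields are illustrative assumptions, not part of the diff): `LLM.score` routes cross-encoder models through `_cross_encoding_score`, so its pooling parameters now carry the flag and the cross-encoder activation is applied to the logits.

```python
from vllm import LLM

# Illustrative checkpoint; any supported cross-encoder
# sequence-classification model should behave the same way.
llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")

# Internally this now builds PoolingParams(use_cross_encoder=True),
# so the cross-encoder activation (sigmoid by default) is used
# instead of the classification default.
outputs = llm.score("What is the capital of France?",
                    ["Paris is the capital of France.",
                     "The Eiffel Tower is in Paris."])
for output in outputs:
    print(output.outputs.score)
```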

vllm/entrypoints/openai/protocol.py

Lines changed: 6 additions & 4 deletions
@@ -1156,8 +1156,9 @@ class ScoreRequest(OpenAIBaseModel):
 
     # --8<-- [end:score-extra-params]
 
-    def to_pooling_params(self):
-        return PoolingParams(additional_data=self.additional_data)
+    def to_pooling_params(self, *, use_cross_encoder: bool = False):
+        return PoolingParams(use_cross_encoder=use_cross_encoder,
+                             additional_data=self.additional_data)
 
 
 class RerankRequest(OpenAIBaseModel):
@@ -1182,8 +1183,9 @@ class RerankRequest(OpenAIBaseModel):
 
     # --8<-- [end:rerank-extra-params]
 
-    def to_pooling_params(self):
-        return PoolingParams(additional_data=self.additional_data)
+    def to_pooling_params(self, *, use_cross_encoder: bool = False):
+        return PoolingParams(use_cross_encoder=use_cross_encoder,
+                             additional_data=self.additional_data)
 
 
 class RerankDocument(BaseModel):
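
A minimal sketch of how the new keyword-only argument is meant to be called (assuming vLLM at this commit; the request field values are illustrative):

```python
from vllm.entrypoints.openai.protocol import ScoreRequest

# Field values are placeholders; only to_pooling_params matters here.
request = ScoreRequest(model="my-cross-encoder",
                       text_1="query",
                       text_2="document")

# The scoring endpoint opts in explicitly; plain classification
# requests keep the default of False.
params = request.to_pooling_params(use_cross_encoder=True)
assert params.use_cross_encoder
```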

vllm/entrypoints/openai/serving_score.py

Lines changed: 4 additions & 6 deletions
@@ -25,9 +25,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
-                                               PreTrainedTokenizer,
-                                               PreTrainedTokenizerFast)
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import make_async, merge_async_iterators
 
 logger = init_logger(__name__)
@@ -50,7 +48,7 @@ def __init__(
 
     async def _embedding_score(
         self,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        tokenizer: AnyTokenizer,
         texts_1: list[str],
         texts_2: list[str],
         request: Union[RerankRequest, ScoreRequest],
@@ -141,7 +139,7 @@ async def _embedding_score(
 
     async def _cross_encoding_score(
         self,
-        tokenizer: Union[AnyTokenizer],
+        tokenizer: AnyTokenizer,
         texts_1: list[str],
         texts_2: list[str],
         request: Union[RerankRequest, ScoreRequest],
@@ -190,7 +188,7 @@ async def _cross_encoding_score(
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
 
-        pooling_params = request.to_pooling_params()
+        pooling_params = request.to_pooling_params(use_cross_encoder=True)
 
         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"

vllm/model_executor/layers/pooler.py

Lines changed: 31 additions & 11 deletions
@@ -15,6 +15,7 @@
 from vllm.model_executor.pooling_metadata import PoolingTensors
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.transformers_utils.config import (
+    get_classification_activation_function,
     get_cross_encoder_activation_function)
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
 
@@ -388,15 +389,14 @@ def __init__(
         self.classifier = classifier
         self.pooler = pooler
 
-        if config.task == "score":
-            self.default_activation_function = \
-                get_cross_encoder_activation_function(config.hf_config)
-        elif config.task == "classify":
-            self.default_activation_function = nn.Sigmoid() \
-                if config.hf_config.num_labels == 1 else nn.Softmax()
-        else:
-            raise NotImplementedError(f"task={config.task!r} is not supported"
-                                      " with the classification pooler")
+        self.classification_act_fn = get_classification_activation_function(
+            config.hf_config)
+        self.cross_encoder_act_fn = get_cross_encoder_activation_function(
+            config.hf_config)
+
+    def _get_act_fn(self, use_cross_encoder: bool):
+        return (self.cross_encoder_act_fn
+                if use_cross_encoder else self.classification_act_fn)
 
     def get_prompt_lens(
         self,
@@ -446,8 +446,28 @@ def forward(
         # apply classifier once on the full batch if possible
         pooled_output = self.classifier(pooled_output)
 
-        # shape: (batch_size, num_labels)
-        scores = self.default_activation_function(pooled_output)
+        if isinstance(pooling_metadata, V0PoolingMetadata):
+            use_cross_encoder_list = [
+                pooling_param.use_cross_encoder
+                for _, pooling_param in pooling_metadata.seq_groups
+            ]
+        else:
+            use_cross_encoder_list = [
+                pooling_param.use_cross_encoder
+                for pooling_param in pooling_metadata.pooling_params
+            ]
+
+        # shape of scores: (batch_size, num_labels)
+        if all(use_cross_encoder == use_cross_encoder_list[0]
+               for use_cross_encoder in use_cross_encoder_list):
+            act_fn = self._get_act_fn(use_cross_encoder_list[0])
+            scores = act_fn(pooled_output)
+        else:
+            scores = torch.stack([
+                self._get_act_fn(use_cross_encoder)(vecs)
+                for use_cross_encoder, vecs in zip(use_cross_encoder_list,
+                                                   pooled_output)
+            ])
 
         pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
         return PoolerOutput(outputs=pooled_outputs)
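
To isolate the new dispatch logic, here is a self-contained sketch of the same pattern (the function names and example shapes are assumptions for illustration, not vLLM's internals): when every request in the batch agrees on the flag, one activation is applied to the full tensor; otherwise each pooled row gets its own activation and the results are stacked back into a batch.

```python
import torch
import torch.nn as nn

classification_act_fn = nn.Softmax(dim=-1)  # "classify" default for >1 label
cross_encoder_act_fn = nn.Sigmoid()         # "score" (cross-encoder) default


def _get_act_fn(use_cross_encoder: bool) -> nn.Module:
    return cross_encoder_act_fn if use_cross_encoder else classification_act_fn


def apply_activation(pooled_output: torch.Tensor,
                     use_cross_encoder_list: list[bool]) -> torch.Tensor:
    # Fast path: one activation over the whole (batch, num_labels) tensor.
    if all(flag == use_cross_encoder_list[0]
           for flag in use_cross_encoder_list):
        return _get_act_fn(use_cross_encoder_list[0])(pooled_output)

    # Mixed batch: apply each request's own activation to its row.
    return torch.stack([
        _get_act_fn(flag)(vecs)
        for flag, vecs in zip(use_cross_encoder_list, pooled_output)
    ])


scores = apply_activation(torch.randn(3, 2), [True, False, True])
print(scores.shape)  # torch.Size([3, 2])
```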

vllm/model_executor/models/bert.py

Lines changed: 0 additions & 5 deletions
@@ -25,8 +25,6 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
-from vllm.transformers_utils.config import (
-    get_cross_encoder_activation_function)
 
 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import WeightsMapper, maybe_prefix
@@ -462,9 +460,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
 
-        self.default_activation_function = \
-            get_cross_encoder_activation_function(config)
-
         self.num_labels = config.num_labels
         self.bert = BertModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "bert"),

vllm/model_executor/models/roberta.py

Lines changed: 0 additions & 5 deletions
@@ -18,8 +18,6 @@
 from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
-from vllm.transformers_utils.config import (
-    get_cross_encoder_activation_function)
 
 from .bert_with_rope import BertWithRope, JinaRobertaModel
 from .interfaces import SupportsCrossEncoding, SupportsV0Only
@@ -178,9 +176,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
 
-        self.default_activation_function = \
-            get_cross_encoder_activation_function(config)
-
         self.num_labels = config.num_labels
         self.roberta = BertModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "bert"),

vllm/pooling_params.py

Lines changed: 3 additions & 0 deletions
@@ -24,12 +24,14 @@ class PoolingParams(
     """
 
     dimensions: Optional[int] = None
+    use_cross_encoder: bool = False
    additional_data: Optional[Any] = None
     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
 
     def clone(self) -> "PoolingParams":
         """Returns a deep copy of the PoolingParams instance."""
         return PoolingParams(dimensions=self.dimensions,
+                             use_cross_encoder=self.use_cross_encoder,
                              additional_data=self.additional_data)
 
     def verify(self, model_config: "ModelConfig") -> None:
@@ -54,6 +56,7 @@ def verify(self, model_config: "ModelConfig") -> None:
     def __repr__(self) -> str:
         return (f"PoolingParams("
                 f"dimensions={self.dimensions}, "
+                f"use_cross_encoder={self.use_cross_encoder}, "
                 f"additional_metadata={self.additional_data})")
 
     def __post_init__(self) -> None:
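
A quick sanity check of the new field (assuming vLLM at this commit is importable): the flag defaults to False, survives clone(), and appears in the repr, which is what lets it travel with a request all the way to the pooler.

```python
from vllm.pooling_params import PoolingParams

params = PoolingParams(use_cross_encoder=True)

# clone() now carries the flag over instead of silently dropping it.
assert params.clone().use_cross_encoder

print(params)
# PoolingParams(dimensions=None, use_cross_encoder=True, additional_metadata=None)
```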

vllm/transformers_utils/config.py

Lines changed: 11 additions & 9 deletions
@@ -866,24 +866,26 @@ def try_get_generation_config(
     return None
 
 
-def get_cross_encoder_activation_function(config: PretrainedConfig):
+def get_classification_activation_function(config: PretrainedConfig):
+    return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax()
+
 
+def get_cross_encoder_activation_function(config: PretrainedConfig):
     function_name: Optional[str] = None
-    if hasattr(config, "sentence_transformers") and "activation_fn" in \
-            config.sentence_transformers:
+    if (hasattr(config, "sentence_transformers")
+            and "activation_fn" in config.sentence_transformers):
         function_name = config.sentence_transformers["activation_fn"]
-
     elif (hasattr(config, "sbert_ce_default_activation_function")
           and config.sbert_ce_default_activation_function is not None):
         function_name = config.sbert_ce_default_activation_function
 
     if function_name is not None:
-        assert function_name.startswith("torch.nn.modules."), \
-            "Loading of activation functions is restricted to " \
-            "torch.nn.modules for security reasons"
+        assert function_name.startswith("torch.nn.modules."), (
+            "Loading of activation functions is restricted to "
+            "torch.nn.modules for security reasons")
         return resolve_obj_by_qualname(function_name)()
-    else:
-        return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
+
+    return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
 
 
 def try_get_safetensors_metadata(
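
The behavioural split between the two helpers can be exercised with a stand-in config (a sketch; SimpleNamespace stands in for a real PretrainedConfig, which these helpers only touch via attributes):

```python
from types import SimpleNamespace

import torch.nn as nn

from vllm.transformers_utils.config import (
    get_classification_activation_function,
    get_cross_encoder_activation_function)

# Single label: both helpers fall back to sigmoid.
cfg = SimpleNamespace(num_labels=1)
assert isinstance(get_classification_activation_function(cfg), nn.Sigmoid)
assert isinstance(get_cross_encoder_activation_function(cfg), nn.Sigmoid)

# Multiple labels: classification softmaxes; the cross-encoder default
# without any configured activation is identity (raw logits).
cfg = SimpleNamespace(num_labels=3)
assert isinstance(get_classification_activation_function(cfg), nn.Softmax)
assert isinstance(get_cross_encoder_activation_function(cfg), nn.Identity)

# An explicitly configured activation is honoured, but only from
# torch.nn.modules (the assert guards against loading arbitrary objects).
cfg = SimpleNamespace(
    num_labels=3,
    sentence_transformers={
        "activation_fn": "torch.nn.modules.activation.Sigmoid"
    })
assert isinstance(get_cross_encoder_activation_function(cfg), nn.Sigmoid)
```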
