Skip to content

Commit e11d011

Browse files
committed
+ as_seq_cls_model
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent f905250 commit e11d011

File tree

7 files changed

+90
-96
lines changed

7 files changed

+90
-96
lines changed

docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ Specified using `--task classify`.
447447
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
448448

449449
If your model is not in the above list, we will try to automatically convert the model using
450-
[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
450+
[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
451451

452452
#### Sentence Pair Scoring
453453

docs/serving/openai_compatible_server.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
379379

380380
Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).
381381

382-
We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
382+
We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
383383

384384
Code example: <gh-file:examples/online_serving/openai_classification_client.py>
385385

tests/models/test_registry.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from vllm.model_executor.models import (is_pooling_model,
1010
is_text_generation_model,
1111
supports_multimodal)
12-
from vllm.model_executor.models.adapters import (as_classification_model,
13-
as_embedding_model,
14-
as_reward_model)
12+
from vllm.model_executor.models.adapters import (as_embedding_model,
13+
as_reward_model,
14+
as_seq_cls_model)
1515
from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
1616
_SPECULATIVE_DECODING_MODELS,
1717
_TEXT_GENERATION_MODELS,
@@ -38,7 +38,7 @@ def test_registry_imports(model_arch):
3838
assert is_text_generation_model(model_cls)
3939

4040
# All vLLM models should be convertible to a pooling model
41-
assert is_pooling_model(as_classification_model(model_cls))
41+
assert is_pooling_model(as_seq_cls_model(model_cls))
4242
assert is_pooling_model(as_embedding_model(model_cls))
4343
assert is_pooling_model(as_reward_model(model_cls))
4444

vllm/model_executor/model_loader/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from vllm.model_executor.layers.quantization.base_config import (
2222
QuantizationConfig, QuantizeMethodBase)
2323
from vllm.model_executor.models import ModelRegistry
24-
from vllm.model_executor.models.adapters import (as_classification_model,
25-
as_embedding_model,
26-
as_reward_model)
24+
from vllm.model_executor.models.adapters import (as_embedding_model,
25+
as_reward_model,
26+
as_seq_cls_model)
2727
from vllm.utils import is_pin_memory_available
2828

2929
logger = init_logger(__name__)
@@ -244,8 +244,8 @@ def get_model_architecture(
244244
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
245245
if model_config.task == "embed":
246246
model_cls = as_embedding_model(model_cls)
247-
elif model_config.task == "classify":
248-
model_cls = as_classification_model(model_cls)
247+
elif model_config.task in ["classify", "score"]:
248+
model_cls = as_seq_cls_model(model_cls)
249249
elif model_config.task == "reward":
250250
model_cls = as_reward_model(model_cls)
251251

vllm/model_executor/models/adapters.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
from collections.abc import Iterable
5-
from typing import TYPE_CHECKING, Any, Optional, TypeVar
5+
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
66

77
import torch
88
import torch.nn as nn
@@ -145,9 +145,9 @@ def as_embedding_model(cls: _T) -> _T:
145145
return ModelForEmbedding # type: ignore
146146

147147

148-
def as_classification_model(cls: _T) -> _T:
148+
def as_seq_cls_model(cls: _T) -> _T:
149149
"""
150-
Subclass an existing vLLM model to support classification.
150+
Subclass an existing vLLM model to support classify and score tasks.
151151
152152
By default, the class probabilities are extracted from the softmaxed
153153
hidden state corresponding to the last token.
@@ -164,7 +164,9 @@ def as_classification_model(cls: _T) -> _T:
164164
# Lazy import
165165
from vllm.config import VllmConfig
166166
from vllm.model_executor.layers.linear import RowParallelLinear
167-
from vllm.model_executor.layers.pooler import PoolingType
167+
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
168+
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
169+
from vllm.model_executor.pooling_metadata import PoolingMetadata
168170
from vllm.sequence import IntermediateTensors
169171

170172
from .utils import maybe_prefix
@@ -176,7 +178,8 @@ def as_classification_model(cls: _T) -> _T:
176178
default_softmax=True,
177179
)
178180

179-
class ModelForClassification(ModelForPooling):
181+
class ModelForSequenceClassification(ModelForPooling,
182+
SupportsCrossEncoding):
180183

181184
def __init__(
182185
self,
@@ -186,10 +189,18 @@ def __init__(
186189
**kwargs: Any,
187190
) -> None:
188191
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
192+
self.config_verify(vllm_config)
189193

190194
config = vllm_config.model_config.hf_config
191195
quant_config = vllm_config.quant_config
192196

197+
self.task = vllm_config.model_config.task
198+
self.pooling_type = (
199+
vllm_config.model_config.pooler_config.pooling_type)
200+
201+
if self.task == "score":
202+
assert config.num_labels == 1
203+
193204
self.score = RowParallelLinear(config.hidden_size,
194205
config.num_labels,
195206
quant_config=quant_config,
@@ -198,24 +209,56 @@ def __init__(
198209
prefix=maybe_prefix(
199210
prefix, "score"))
200211

212+
def config_verify(self, vllm_config):
213+
# Hook that lets subclasses validate or modify model_config
214
# for models that differ slightly from the default
215+
pass
216+
201217
def forward(
202218
self,
203219
input_ids: torch.Tensor,
204220
positions: torch.Tensor,
205221
intermediate_tensors: Optional[IntermediateTensors] = None,
206222
inputs_embeds: Optional[torch.Tensor] = None,
207223
) -> torch.Tensor:
208-
hidden_states = super().forward(input_ids, positions,
209-
intermediate_tensors,
210-
inputs_embeds)
211-
logits, _ = self.score(hidden_states)
212-
return logits
224+
return super().forward(input_ids, positions, intermediate_tensors,
225+
inputs_embeds)
226+
227+
def pooler(
228+
self,
229+
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
230+
pooling_metadata: PoolingMetadata,
231+
) -> PoolerOutput:
232+
233+
def get_logits(hidden_states):
234+
if isinstance(hidden_states, list):
235+
logits = [self.score(state)[0] for state in hidden_states]
236+
else:
237+
logits, _ = self.score(hidden_states)
238+
return logits
239+
240+
if self.pooling_type == PoolingType.ALL:
241+
logits = get_logits(hidden_states)
242+
return self._pooler(logits, pooling_metadata)
243+
else:
244+
hidden_states = self._pooler.extract_states(
245+
hidden_states, pooling_metadata)
246+
logits = get_logits(hidden_states)
247+
pooled_data = self._pooler.head(logits, pooling_metadata)
248+
249+
if self.task == "score":
250+
pooled_data = [data.squeeze(-1) for data in pooled_data]
251+
252+
pooled_outputs = [
253+
self._pooler.build_output(data) for data in pooled_data
254+
]
255+
return PoolerOutput(outputs=pooled_outputs)
213256

214257

215-
ModelForClassification.__name__ = \
216-
_get_pooling_model_name(cls.__name__, "ForClassification")
258+
ModelForSequenceClassification.__name__ = \
259+
_get_pooling_model_name(cls.__name__, "ForSequenceClassification")
217260

218-
return ModelForClassification # type: ignore
261+
return ModelForSequenceClassification # type: ignore
219262

220263

221264
def as_reward_model(cls: _T) -> _T:

vllm/model_executor/models/qwen3.py

Lines changed: 20 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,14 @@
3838
from vllm.model_executor.layers.linear import (QKVParallelLinear,
3939
RowParallelLinear)
4040
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41-
from vllm.model_executor.layers.pooler import Pooler, PoolingType
4241
from vllm.model_executor.layers.quantization import QuantizationConfig
4342
from vllm.model_executor.layers.rotary_embedding import get_rope
4443
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
45-
from vllm.model_executor.pooling_metadata import PoolingMetadata
4644
from vllm.model_executor.sampling_metadata import SamplingMetadata
47-
from vllm.sequence import IntermediateTensors, PoolerOutput
45+
from vllm.sequence import IntermediateTensors
4846

49-
from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
47+
from .adapters import as_seq_cls_model
48+
from .interfaces import SupportsLoRA, SupportsPP
5049
from .qwen2 import Qwen2MLP as Qwen3MLP
5150
from .qwen2 import Qwen2Model
5251
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
@@ -323,69 +322,31 @@ def load_weights(self, weights: Iterable[tuple[str,
323322
return loader.load_weights(weights)
324323

325324

326-
class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
327-
SupportsCrossEncoding):
325+
class Qwen3ForSequenceClassification(as_seq_cls_model(Qwen3ForCausalLM)):
328326

329327
def __init__(
330328
self,
331329
vllm_config: "VllmConfig",
332330
prefix: str = "",
333331
) -> None:
334-
super().__init__()
332+
super().__init__(vllm_config=vllm_config, prefix=prefix)
335333

334+
def config_verify(self, vllm_config: "VllmConfig"):
336335
config = vllm_config.model_config.hf_config
337-
quant_config = vllm_config.quant_config
338-
pooler_config = vllm_config.model_config.pooler_config
339-
340-
self.vllm_config = vllm_config
341-
self.config = config
342-
self.quant_config = quant_config
343-
self.prefix = prefix
344-
self.model = Qwen3Model(vllm_config=vllm_config,
345-
prefix=maybe_prefix(prefix, "model"))
346-
self.score = RowParallelLinear(config.hidden_size,
347-
config.num_labels,
348-
quant_config=quant_config,
349-
input_is_parallel=False,
350-
bias=False,
351-
prefix=maybe_prefix(prefix, "score"))
352-
353-
self._pooler = Pooler.from_config_with_defaults(
354-
pooler_config,
355-
pooling_type=PoolingType.LAST,
356-
normalize=False,
357-
softmax=True)
358-
359-
def forward(
360-
self,
361-
input_ids: torch.Tensor,
362-
positions: torch.Tensor,
363-
intermediate_tensors: Optional[IntermediateTensors] = None,
364-
inputs_embeds: Optional[torch.Tensor] = None,
365-
) -> torch.Tensor:
366-
return self.model(input_ids=input_ids,
367-
positions=positions,
368-
inputs_embeds=inputs_embeds,
369-
intermediate_tensors=intermediate_tensors)
370336

371-
def pooler(
372-
self,
373-
hidden_states: torch.Tensor,
374-
pooling_metadata: PoolingMetadata,
375-
) -> Optional[PoolerOutput]:
376-
hidden_states = self._pooler.extract_states(hidden_states,
377-
pooling_metadata)
337+
is_original_qwen3_reranker = getattr(config,
338+
"is_original_qwen3_reranker",
339+
False)
378340

379-
if isinstance(hidden_states, list):
380-
logits = [self.score(state)[0] for state in hidden_states]
381-
else:
382-
logits, _ = self.score(hidden_states)
341+
if not is_original_qwen3_reranker:
342+
return
383343

384-
pooled_data = self._pooler.head(logits, pooling_metadata)
385-
pooled_outputs = [
386-
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
387-
]
388-
return PoolerOutput(outputs=pooled_outputs)
344+
tokens = getattr(config, "classifier_from_token", None)
345+
assert tokens is not None and len(tokens) == 2, \
346+
("Try loading the original Qwen3 Reranker?, see: "
347+
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
348+
config.num_labels = 1
349+
self.vllm_config = vllm_config
389350

390351
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
391352
is_original_qwen3_reranker = getattr(self.config,
@@ -400,22 +361,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
400361

401362
def load_weights_from_original_qwen3_reranker(
402363
self, weights: Iterable[tuple[str, torch.Tensor]]):
403-
tokens = getattr(self.config, "classifier_from_token", None)
404-
assert tokens is not None and len(tokens) == 2, \
405-
("Try loading the original Qwen3 Reranker?, see: "
406-
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
407364

408-
self.config.num_labels = 1
409365
model_config = self.vllm_config.model_config
410-
366+
tokens = getattr(self.config, "classifier_from_token", None)
411367
device = self.score.weight.device
412-
self.score = RowParallelLinear(self.config.hidden_size,
413-
self.config.num_labels,
414-
quant_config=self.quant_config,
415-
input_is_parallel=False,
416-
bias=False,
417-
prefix=maybe_prefix(
418-
self.prefix, "score")).to(device)
419368

420369
if self.config.tie_word_embeddings:
421370
self.lm_head = self.model.embed_tokens
@@ -443,5 +392,6 @@ def load_weights_from_original_qwen3_reranker(
443392
self.score.weight.data.copy_(weight)
444393

445394
del self.lm_head
446-
loaded_weights.add("classifier.weight")
395+
loaded_weights.add("score.weight")
447396
loaded_weights.discard("lm_head.weight")
397+
return loaded_weights

vllm/model_executor/models/registry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,6 @@
157157
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
158158
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
159159
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
160-
# [Auto-converted (see adapters.py)]
161-
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
162160
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
163161
# input and output. I am adding it here because it piggy-backs on embedding
164162
# models for the time being.
@@ -173,6 +171,9 @@
173171
"RobertaForSequenceClassification"),
174172
"ModernBertForSequenceClassification": ("modernbert",
175173
"ModernBertForSequenceClassification"),
174+
# [Auto-converted (see adapters.py)]
175+
"GemmaForSequenceClassification": ("gemma", "GemmaForCausalLM"),
176+
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
176177
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
177178
}
178179

0 commit comments

Comments
 (0)