
Commit 207d477

noooop authored and WorldExplored committed
[Model] Re-add the implicit conversion feature for as_seq_cls_model (vllm-project#21103)
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 9fdfb7f commit 207d477

11 files changed: +165 -75 lines changed
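
In short, the re-added feature lets a checkpoint whose config declares an architecture such as Qwen3ForSequenceClassification, for which vLLM ships no dedicated model class, be loaded by mapping the name back to Qwen3ForCausalLM and wrapping that implementation with as_seq_cls_model. A minimal sketch of the wrapping step (the adapter call mirrors the model-loader change below; Qwen2ForCausalLM and its import path are assumed here for illustration):

from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM

# Wrap the causal-LM implementation so it exposes a pooling/classification
# head; get_model_architecture() now performs this call automatically when
# model_config.task == "classify".
Qwen2SeqCls = as_seq_cls_model(Qwen2ForCausalLM)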

tests/models/registry.py

Lines changed: 20 additions & 12 deletions
@@ -265,7 +265,6 @@ def check_available_online(
     "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
-    "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
     "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
     "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
     "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
@@ -292,7 +291,6 @@ def check_available_online(
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
-    "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
@@ -311,7 +309,6 @@ def check_available_online(
     "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
     "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
     "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
-    "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
     "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
     "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
     "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
@@ -324,20 +321,29 @@ def check_available_online(
                                         is_available_online=False), # noqa: E501
 }
 
-_CROSS_ENCODER_EXAMPLE_MODELS = {
-    # [Text-only]
+_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
+    # [Decoder-only]
+    "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501
+
+    # [Cross-encoder]
     "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
-    "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
-                                                      v0_only=True,
-                                                      hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
-                                                                    "classifier_from_token": ["Yes"], # noqa: E501
-                                                                    "method": "no_post_processing"}), # noqa: E501
-    "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
     "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
     "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
     "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
 }
 
+_AUTOMATIC_CONVERTED_MODELS = {
+    # Use as_seq_cls_model for automatic conversion
+    "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
+                                                      v0_only=True,
+                                                      hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
+                                                                    "classifier_from_token": ["Yes"], # noqa: E501
+                                                                    "method": "no_post_processing"}), # noqa: E501
+    "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
+    "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
+    "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
+}
+
 _MULTIMODAL_EXAMPLE_MODELS = {
     # [Decoder-only]
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
@@ -449,6 +455,7 @@ def check_available_online(
     "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501
 }
 
+
 _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
     "EAGLEModel": _HfExamplesInfo("JackFram/llama-68m",
                                   speculative_model="abhigoyal/vllm-eagle-llama-68m-random"), # noqa: E501
@@ -489,7 +496,7 @@ def check_available_online(
 _EXAMPLE_MODELS = {
     **_TEXT_GENERATION_EXAMPLE_MODELS,
     **_EMBEDDING_EXAMPLE_MODELS,
-    **_CROSS_ENCODER_EXAMPLE_MODELS,
+    **_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS,
     **_MULTIMODAL_EXAMPLE_MODELS,
     **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
     **_TRANSFORMERS_MODELS,
@@ -522,3 +529,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
 
 
 HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
+AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS)
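
The entries moved into _AUTOMATIC_CONVERTED_MODELS have no native *ForSequenceClassification implementation in vLLM; the new AUTO_EXAMPLE_MODELS registry lets the tests exercise them through the implicit conversion path. A hedged usage sketch of what that path enables offline (assuming the LLM.classify() API and its ClassificationOutput.probs field behave as in current vLLM releases):

from vllm import LLM

# "jason9693/Qwen2.5-1.5B-apeach" is listed in _AUTOMATIC_CONVERTED_MODELS:
# its Qwen2ForSequenceClassification architecture should be resolved to
# Qwen2ForCausalLM and wrapped with as_seq_cls_model at load time.
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
(output,) = llm.classify(["vLLM is a high-throughput inference engine."])
print(output.outputs.probs)  # per-class probabilities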

tests/models/test_initialization.py

Lines changed: 21 additions & 8 deletions
@@ -13,20 +13,21 @@
 from vllm.v1.engine.core import EngineCore as V1EngineCore
 
 from ..utils import create_new_process_for_each_test
-from .registry import HF_EXAMPLE_MODELS
+from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels
 
 
-@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
 @create_new_process_for_each_test()
-def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
-    """The reason for using create_new_process_for_each_test is to avoid
-    the WARNING:
-    "We must use the 'spawn' multiprocessing start method. Overriding
+def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
+                   EXAMPLE_MODELS: HfExampleModels):
+    """The reason for using create_new_process_for_each_test is to avoid
+    the WARNING:
+    "We must use the 'spawn' multiprocessing start method. Overriding
     VLLM_WORKER_MULTIPROC_METHOD to 'spawn'."
-    The spawn process causes the _initialize_kv_caches_v1 function below to
+    The spawn process causes the _initialize_kv_caches_v1 function below to
     become ineffective.
     """
-    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+
+    model_info = EXAMPLE_MODELS.get_hf_info(model_arch)
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")
 
@@ -127,3 +128,15 @@ def _initialize_kv_caches_v1(self, vllm_config):
         load_format="dummy",
         hf_overrides=hf_overrides,
     )
+
+
+@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
+def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
+    can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS)
+
+
+@pytest.mark.parametrize("model_arch",
+                         AUTO_EXAMPLE_MODELS.get_supported_archs())
+def test_implicit_converted_models(model_arch: str,
+                                   monkeypatch: pytest.MonkeyPatch):
+    can_initialize(model_arch, monkeypatch, AUTO_EXAMPLE_MODELS)

tests/models/test_transformers.py

Lines changed: 35 additions & 0 deletions
@@ -138,3 +138,38 @@ def test_quantization(
         name_0="transformers",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["jason9693/Qwen2.5-1.5B-apeach"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_classify(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    monkeypatch,
+) -> None:
+    import torch
+    from transformers import AutoModelForSequenceClassification
+
+    with vllm_runner(model,
+                     max_model_len=512,
+                     dtype=dtype,
+                     model_impl="transformers") as vllm_model:
+        vllm_outputs = vllm_model.classify(example_prompts)
+
+    with hf_runner(model,
+                   dtype=dtype,
+                   auto_cls=AutoModelForSequenceClassification) as hf_model:
+        hf_outputs = hf_model.classify(example_prompts)
+
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output)
+        vllm_output = torch.tensor(vllm_output)
+
+        assert torch.allclose(hf_output, vllm_output,
+                              1e-3 if dtype == "float" else 1e-2)

vllm/config.py

Lines changed: 25 additions & 21 deletions
@@ -551,7 +551,7 @@ def __post_init__(self) -> None:
         # For pooling models, self.task is used to indicate the
         # user-selected task
         if self.task == "score":
-            if self.registry.is_cross_encoder_model(self.architectures):
+            if self._is_classify_task(self.architectures):
                 self.task = "classify"
             else:
                 self.task = "embed"
@@ -806,21 +806,24 @@ def _verify_tokenizer_mode(self) -> None:
                 f"one of {get_args(TokenizerMode)}.")
         self.tokenizer_mode = tokenizer_mode
 
+    def _is_classify_task(self, architectures: list[str]):
+        for arch in architectures:
+            if arch.endswith("ForSequenceClassification"):
+                return True
+        return self.registry.is_cross_encoder_model(architectures)
+
     def _get_preferred_pooling_task(
         self,
         architectures: list[str],
     ) -> _ResolvedTask:
         model_id = self.model
         if get_pooling_config(model_id, self.revision):
             return "embed"
-        if self.registry.is_cross_encoder_model(architectures):
-            return "classify"
         if self.registry.is_transcription_model(architectures):
             return "transcription"
 
         suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
             # Other models follow this pattern
-            ("ForSequenceClassification", "classify"),
             ("EmbeddingModel", "embed"),
             ("RewardModel", "reward"),
         ]
@@ -878,11 +881,14 @@ def _get_supported_tasks(
         self,
         task_option: TaskOption,
     ) -> dict[RunnerType, list[_ResolvedTask]]:
-        return {
-            "generate": self._get_supported_generation_tasks(task_option),
-            "pooling": self._get_supported_pooling_tasks(task_option),
-            "draft": ["draft"]
-        }
+        if self._is_classify_task(self.architectures):
+            return {"generate": [], "pooling": ["classify"], "draft": []}
+        else:
+            return {
+                "generate": self._get_supported_generation_tasks(task_option),
+                "pooling": self._get_supported_pooling_tasks(task_option),
+                "draft": ["draft"]
+            }
 
     def _get_supported_runner_types(
         self,
@@ -925,12 +931,16 @@ def _resolve_runner(
                 f"Available tasks for runner={task_runner!r}: "
                 f"{supported_tasks[task_runner]}")
 
+        if "classify" in supported_tasks.get("pooling", []):
+            # When multiple pooling tasks are present, default to
+            # pooling (eg cross-encoder) for non-standard architectures.
+            return "pooling"
+
         suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
             ("ForCausalLM", "generate"),
             ("ForConditionalGeneration", "generate"),
             ("ChatModel", "generate"),
             ("LMHeadModel", "generate"),
-            ("ForSequenceClassification", "pooling"),
             ("EmbeddingModel", "pooling"),
             ("RewardModel", "pooling"),
         ]
@@ -940,10 +950,6 @@ def _resolve_runner(
             if arch.endswith(suffix) and pref_runner in supported_runner_types:
                 return pref_runner
 
-        if "classify" in supported_tasks.get("pooling", []):
-            # When multiple pooling tasks are present, default to
-            # pooling (eg cross-encoder) for non-standard architectures.
-            return "pooling"
         if "generate" in supported_runner_types:
             return "generate"
         if "pooling" in supported_runner_types:
@@ -1525,7 +1531,7 @@ def is_v1_compatible(self) -> bool:
 
     @property
     def is_matryoshka(self) -> bool:
-        return (hasattr(self.hf_config, "matryoshka_dimensions")
+        return (bool(getattr(self.hf_config, "matryoshka_dimensions", None))
                 or getattr(self.hf_config, "is_matryoshka", False))
 
     @property
@@ -1539,13 +1545,11 @@ def use_pad_token(self) -> bool:
         return getattr(self.hf_config, "use_pad_token", True)
 
     def get_and_verify_max_len(self, max_model_len: int):
-        # For pooling models, the tokenizer's `model_max_length` is often a
-        # reliable source for the maximum sequence length. However, for
-        # generative models, this can be incorrect and unduly limit the
-        # context window (e.g., DeepSeek-R1). Therefore, we only consider
-        # tokenizer_config for pooling models.
+        # Consider max_model_len in tokenizer_config only when
+        # pooling models use absolute position_embedding.
         tokenizer_config = None
-        if self.runner_type == "pooling":
+        if (self.runner_type == "pooling" and getattr(
+                self.hf_config, "position_embedding_type", "") == "absolute"):
             tokenizer_config = try_get_tokenizer_config(
                 self.tokenizer,
                 trust_remote_code=self.trust_remote_code,
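
Net effect of the config changes: _is_classify_task treats any *ForSequenceClassification architecture as a classify workload even when the registry has no cross-encoder entry for it, and the classify check in _resolve_runner now runs before the ForCausalLM suffix heuristic can claim the model for generation. A standalone sketch of that decision order (simplified; is_cross_encoder stands in for the registry lookup):

def runs_as_classify(architectures: list[str], is_cross_encoder: bool) -> bool:
    """Return True if the model should be routed to the pooling runner
    with the single "classify" task."""
    # Any *ForSequenceClassification arch counts, even without a native class.
    if any(arch.endswith("ForSequenceClassification") for arch in architectures):
        return True
    # Otherwise defer to the registry's cross-encoder knowledge.
    return is_cross_encoder

assert runs_as_classify(["Qwen3ForSequenceClassification"], False)
assert not runs_as_classify(["Qwen3ForCausalLM"], False)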

vllm/model_executor/model_loader/utils.py

Lines changed: 26 additions & 4 deletions
@@ -22,7 +22,8 @@
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.adapters import (as_embedding_model,
-                                                 as_reward_model)
+                                                 as_reward_model,
+                                                 as_seq_cls_model)
 from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.utils import is_pin_memory_available
 
@@ -238,22 +239,43 @@ def get_model_architecture(
     vllm_supported_archs = ModelRegistry.get_supported_archs()
     vllm_not_supported = not any(arch in vllm_supported_archs
                                  for arch in architectures)
+
+    if vllm_not_supported:
+        # try automatic conversion in adapters.py
+        for arch in architectures:
+            if not arch.endswith("ForSequenceClassification"):
+                continue
+
+            assert model_config.task == "classify"
+            causal_lm_arch = arch.replace("ForSequenceClassification",
+                                          "ForCausalLM")
+            causal_lm_arch_vllm_supported = (causal_lm_arch
+                                             in vllm_supported_archs)
+            if not causal_lm_arch_vllm_supported:
+                continue
+
+            architectures = [causal_lm_arch]
+            vllm_not_supported = False
+            break
+
     if (model_config.model_impl == ModelImpl.TRANSFORMERS or
             model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
         architectures = resolve_transformers_arch(model_config, architectures)
+        logger.debug_once("Resolve transformers arch %s", str(architectures))
     elif (model_config.quantization is not None
           and model_config.quantization not in mixtral_supported
           and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
 
     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
     if model_config.task == "embed":
+        logger.debug_once("Automatic conversion using `as_embedding_model`.")
         model_cls = as_embedding_model(model_cls)
     elif model_config.task == "classify":
-        # Cannot automatically run as_seq_cls_model,
-        # otherwise it will cause a circular reference on is_cross_encoder_model
-        pass
+        logger.debug_once("Automatic conversion using `as_seq_cls_model`.")
+        model_cls = as_seq_cls_model(model_cls)
     elif model_config.task == "reward":
+        logger.debug_once("Automatic conversion using `as_reward_model`.")
         model_cls = as_reward_model(model_cls)
 
     return model_cls, arch
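
The new block in get_model_architecture boils down to: when none of the requested architectures are registered, swap an unsupported *ForSequenceClassification name for its *ForCausalLM base and let as_seq_cls_model add the classification head afterwards. A minimal standalone sketch of that fallback (function name and the supported set are illustrative):

def convert_seq_cls_arch(architectures: list[str],
                         supported: set[str]) -> list[str]:
    """Map an unsupported *ForSequenceClassification architecture onto its
    *ForCausalLM base when that base is registered."""
    for arch in architectures:
        if not arch.endswith("ForSequenceClassification"):
            continue
        causal_lm_arch = arch.replace("ForSequenceClassification", "ForCausalLM")
        if causal_lm_arch in supported:
            return [causal_lm_arch]
    return architectures

print(convert_seq_cls_arch(["Qwen3ForSequenceClassification"],
                           {"Qwen3ForCausalLM"}))
# ['Qwen3ForCausalLM']  -> later wrapped by as_seq_cls_model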

vllm/model_executor/models/adapters.py

Lines changed: 9 additions & 6 deletions
@@ -331,13 +331,13 @@ def load_weights_using_from_2_way_softmax(
 
     false_id = tokenizer.convert_tokens_to_ids(tokens[0])
     true_id = tokenizer.convert_tokens_to_ids(tokens[1])
-    weight = model.lm_head.weight.data[[true_id]].to(
+    score_weight = model.lm_head.weight.data[[true_id]].to(
         torch.float32) - model.lm_head.weight.data[[false_id]].to(
             torch.float32)
 
     param = model.score.weight
     weight_loader = getattr(param, "weight_loader", default_weight_loader)
-    weight_loader(param, weight)
+    weight_loader(param, score_weight)
 
     del model.lm_head
     loaded_weights.add("score.weight")
@@ -350,15 +350,15 @@ def load_weights_no_post_processing(model,
                                                             torch.Tensor]]):
     from vllm.model_executor.layers.vocab_parallel_embedding import (
         ParallelLMHead)
+    from vllm.model_executor.model_loader.weight_utils import (
+        default_weight_loader)
     from vllm.model_executor.models.utils import AutoWeightsLoader
 
     model_config = model.vllm_config.model_config
     tokens = getattr(model.config, "classifier_from_token", [])
     tokens = cast(list[int], tokens)
     assert len(tokens) > 0
 
-    device = model.score.weight.device
-
     if model.config.tie_word_embeddings:
         model.lm_head = model.model.embed_tokens
     else:
@@ -376,8 +376,11 @@ def load_weights_no_post_processing(model,
                               trust_remote_code=model_config.trust_remote_code)
 
     token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
-    score_weight = model.lm_head.weight.data[token_ids].to(device)
-    model.score.weight.data.copy_(score_weight)
+    score_weight = model.lm_head.weight.data[token_ids]
+
+    param = model.score.weight
+    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+    weight_loader(param, score_weight)
 
     del model.lm_head
     loaded_weights.add("score.weight")
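
The rename to score_weight does not change the math in load_weights_using_from_2_way_softmax: for a reranker that answers with two tokens (e.g. "no"/"yes"), softmax over the two logits satisfies p(yes) = sigmoid(z_yes - z_no), so the single-row score head is exactly the difference of the two lm_head rows, taken in float32. A toy check of that identity (shapes and token ids are made up):

import torch

hidden = torch.randn(16)
lm_head = torch.randn(32, 16)   # toy vocab of 32 tokens
false_id, true_id = 5, 9        # hypothetical "no"/"yes" token ids

# Single score row, built the same way as in the adapter code.
score_weight = lm_head[[true_id]].float() - lm_head[[false_id]].float()

logits = lm_head @ hidden
p_true_softmax = torch.softmax(logits[[false_id, true_id]], dim=0)[1]
p_true_sigmoid = torch.sigmoid(score_weight @ hidden)

assert torch.allclose(p_true_softmax, p_true_sigmoid.squeeze(), atol=1e-5)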
