
[Model] Re-add the implicit conversion feature for as_seq_cls_model #21103


Open: wants to merge 5 commits into base: main
32 changes: 20 additions & 12 deletions tests/models/registry.py
@@ -265,7 +265,6 @@ def check_available_online(
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
@@ -292,7 +291,6 @@ def check_available_online(
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
"GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
@@ -311,7 +309,6 @@ def check_available_online(
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
@@ -324,20 +321,29 @@ def check_available_online(
is_available_online=False), # noqa: E501
}

_CROSS_ENCODER_EXAMPLE_MODELS = {
# [Text-only]
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
# [Decoder-only]
"GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501

# [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
v0_only=True,
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
"classifier_from_token": ["Yes"], # noqa: E501
"method": "no_post_processing"}), # noqa: E501
"LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
}

_AUTOMATIC_CONVERTED_MODELS = {
# Use as_seq_cls_model for automatic conversion
"GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
v0_only=True,
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
"classifier_from_token": ["Yes"], # noqa: E501
"method": "no_post_processing"}), # noqa: E501
"LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
}

_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
@@ -447,6 +453,7 @@ def check_available_online(
"JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501
}


_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"EAGLEModel": _HfExamplesInfo("JackFram/llama-68m",
speculative_model="abhigoyal/vllm-eagle-llama-68m-random"), # noqa: E501
@@ -487,7 +494,7 @@ def check_available_online(
_EXAMPLE_MODELS = {
**_TEXT_GENERATION_EXAMPLE_MODELS,
**_EMBEDDING_EXAMPLE_MODELS,
**_CROSS_ENCODER_EXAMPLE_MODELS,
**_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS,
**_MULTIMODAL_EXAMPLE_MODELS,
**_SPECULATIVE_DECODING_EXAMPLE_MODELS,
**_TRANSFORMERS_MODELS,
@@ -520,3 +527,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS)
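
For context, the new _AUTOMATIC_CONVERTED_MODELS table groups the architectures that have no dedicated vLLM class and instead rely on a naming convention: each <X>ForSequenceClassification entry is expected to have a vLLM-registered <X>ForCausalLM counterpart that as_seq_cls_model can wrap. A minimal sketch of that convention (illustrative only, not part of the diff):

def causal_lm_counterpart(arch: str) -> str:
    # Maps a sequence-classification architecture name to the causal-LM
    # architecture that the implicit conversion falls back to.
    assert arch.endswith("ForSequenceClassification")
    return arch.replace("ForSequenceClassification", "ForCausalLM")

assert causal_lm_counterpart("Qwen3ForSequenceClassification") == "Qwen3ForCausalLM"
assert causal_lm_counterpart("LlamaForSequenceClassification") == "LlamaForCausalLM"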
29 changes: 21 additions & 8 deletions tests/models/test_initialization.py
@@ -13,20 +13,21 @@
from vllm.v1.engine.core import EngineCore as V1EngineCore

from ..utils import create_new_process_for_each_test
from .registry import HF_EXAMPLE_MODELS
from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
@create_new_process_for_each_test()
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
"""The reason for using create_new_process_for_each_test is to avoid
the WARNING:
"We must use the 'spawn' multiprocessing start method. Overriding
def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
EXAMPLE_MODELS: HfExampleModels):
"""The reason for using create_new_process_for_each_test is to avoid
the WARNING:
"We must use the 'spawn' multiprocessing start method. Overriding
VLLM_WORKER_MULTIPROC_METHOD to 'spawn'."
The spawn process causes the _initialize_kv_caches_v1 function below to
The spawn process causes the _initialize_kv_caches_v1 function below to
become ineffective.
"""
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)

model_info = EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")

@@ -127,3 +128,15 @@ def _initialize_kv_caches_v1(self, vllm_config):
load_format="dummy",
hf_overrides=hf_overrides,
)


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS)


@pytest.mark.parametrize("model_arch",
AUTO_EXAMPLE_MODELS.get_supported_archs())
def test_implicit_converted_models(model_arch: str,
monkeypatch: pytest.MonkeyPatch):
can_initialize(model_arch, monkeypatch, AUTO_EXAMPLE_MODELS)
35 changes: 35 additions & 0 deletions tests/models/test_transformers.py
@@ -138,3 +138,38 @@ def test_quantization(
name_0="transformers",
name_1="vllm",
)


@pytest.mark.parametrize(
"model",
["jason9693/Qwen2.5-1.5B-apeach"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_classify(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
monkeypatch,
) -> None:
import torch
from transformers import AutoModelForSequenceClassification

with vllm_runner(model,
max_model_len=512,
dtype=dtype,
model_impl="transformers") as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)

with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)

for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)

assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)
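
Beyond the transformers-backend comparison above, the same model can be exercised through vLLM's offline classification API; a minimal sketch (the task argument and output attribute names are assumptions and may differ across vLLM versions):

from vllm import LLM

# Quick manual check mirroring test_classify; not part of the test suite.
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify", max_model_len=512)
outputs = llm.classify(["vLLM is a high-throughput inference engine."])
print(outputs[0].outputs.probs)  # per-class probabilities for the prompt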
46 changes: 25 additions & 21 deletions vllm/config.py
@@ -551,7 +551,7 @@ def __post_init__(self) -> None:
# For pooling models, self.task is used to indicate the
# user-selected task
if self.task == "score":
if self.registry.is_cross_encoder_model(self.architectures):
if self._is_classify_task(self.architectures):
self.task = "classify"
else:
self.task = "embed"
@@ -806,21 +806,24 @@ def _verify_tokenizer_mode(self) -> None:
f"one of {get_args(TokenizerMode)}.")
self.tokenizer_mode = tokenizer_mode

def _is_classify_task(self, architectures: list[str]):
for arch in architectures:
if arch.endswith("ForSequenceClassification"):
return True
return self.registry.is_cross_encoder_model(architectures)

def _get_preferred_pooling_task(
self,
architectures: list[str],
) -> _ResolvedTask:
model_id = self.model
if get_pooling_config(model_id, self.revision):
return "embed"
if self.registry.is_cross_encoder_model(architectures):
return "classify"
if self.registry.is_transcription_model(architectures):
return "transcription"

suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
# Other models follow this pattern
("ForSequenceClassification", "classify"),
("EmbeddingModel", "embed"),
("RewardModel", "reward"),
]
@@ -878,11 +881,14 @@ def _get_supported_tasks(
self,
task_option: TaskOption,
) -> dict[RunnerType, list[_ResolvedTask]]:
return {
"generate": self._get_supported_generation_tasks(task_option),
"pooling": self._get_supported_pooling_tasks(task_option),
"draft": ["draft"]
}
if self._is_classify_task(self.architectures):
return {"generate": [], "pooling": ["classify"], "draft": []}
else:
return {
"generate": self._get_supported_generation_tasks(task_option),
"pooling": self._get_supported_pooling_tasks(task_option),
"draft": ["draft"]
}

def _get_supported_runner_types(
self,
@@ -925,12 +931,16 @@ def _resolve_runner(
f"Available tasks for runner={task_runner!r}: "
f"{supported_tasks[task_runner]}")

if "classify" in supported_tasks.get("pooling", []):
# When multiple pooling tasks are present, default to
# pooling (eg cross-encoder) for non-standard architectures.
return "pooling"

suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("ForSequenceClassification", "pooling"),
("EmbeddingModel", "pooling"),
("RewardModel", "pooling"),
]
@@ -940,10 +950,6 @@ def _resolve_runner(
if arch.endswith(suffix) and pref_runner in supported_runner_types:
return pref_runner

if "classify" in supported_tasks.get("pooling", []):
# When multiple pooling tasks are present, default to
# pooling (eg cross-encoder) for non-standard architectures.
return "pooling"
if "generate" in supported_runner_types:
return "generate"
if "pooling" in supported_runner_types:
@@ -1525,7 +1531,7 @@ def is_v1_compatible(self) -> bool:

@property
def is_matryoshka(self) -> bool:
return (hasattr(self.hf_config, "matryoshka_dimensions")
return (bool(getattr(self.hf_config, "matryoshka_dimensions", None))
or getattr(self.hf_config, "is_matryoshka", False))

@property
@@ -1539,13 +1545,11 @@ def use_pad_token(self) -> bool:
return getattr(self.hf_config, "use_pad_token", True)

def get_and_verify_max_len(self, max_model_len: int):
# For pooling models, the tokenizer's `model_max_length` is often a
# reliable source for the maximum sequence length. However, for
# generative models, this can be incorrect and unduly limit the
# context window (e.g., DeepSeek-R1). Therefore, we only consider
# tokenizer_config for pooling models.
# Consider max_model_len in tokenizer_config only when
# pooling models use absolute position_embedding.
tokenizer_config = None
if self.runner_type == "pooling":
if (self.runner_type == "pooling" and getattr(
self.hf_config, "position_embedding_type", "") == "absolute"):
tokenizer_config = try_get_tokenizer_config(
self.tokenizer,
trust_remote_code=self.trust_remote_code,
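
The core of the config change is the new suffix rule in _is_classify_task; a condensed standalone sketch of its behavior (simplified: the real helper additionally falls back to registry.is_cross_encoder_model):

def is_classify_task(architectures: list[str]) -> bool:
    # Any architecture ending in "ForSequenceClassification" is treated as a
    # classification task, whether or not vLLM has a native class for it.
    return any(arch.endswith("ForSequenceClassification") for arch in architectures)

assert is_classify_task(["Qwen3ForSequenceClassification"])
assert not is_classify_task(["Qwen3ForCausalLM"])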
29 changes: 25 additions & 4 deletions vllm/model_executor/model_loader/utils.py
@@ -22,7 +22,8 @@
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_embedding_model,
as_reward_model)
as_reward_model,
as_seq_cls_model)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.utils import is_pin_memory_available

@@ -238,22 +239,42 @@ def get_model_architecture(
vllm_supported_archs = ModelRegistry.get_supported_archs()
vllm_not_supported = not any(arch in vllm_supported_archs
for arch in architectures)

if vllm_not_supported:
# try automatic conversion in adapters.py
for arch in architectures:
if not arch.endswith("ForSequenceClassification"):
continue

assert model_config.task == "classify"
causal_lm_arch = arch.replace("ForSequenceClassification",
"ForCausalLM")
causal_lm_arch_vllm_supported = (causal_lm_arch
in vllm_supported_archs)

if causal_lm_arch_vllm_supported:
architectures = [causal_lm_arch]
vllm_not_supported = False
break

if (model_config.model_impl == ModelImpl.TRANSFORMERS or
model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
architectures = resolve_transformers_arch(model_config, architectures)
logger.debug_once("Resolve transformers arch %s", str(architectures))
elif (model_config.quantization is not None
and model_config.quantization not in mixtral_supported
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embed":
logger.debug_once("Automatic conversion using `as_embedding_model`.")
model_cls = as_embedding_model(model_cls)
elif model_config.task == "classify":
# Cannot automatically run as_seq_cls_model,
# otherwise it will cause a circular reference on is_cross_encoder_model
pass
logger.debug_once("Automatic conversion using `as_seq_cls_model`.")
model_cls = as_seq_cls_model(model_cls)
elif model_config.task == "reward":
logger.debug_once("Automatic conversion using `as_reward_model`.")
model_cls = as_reward_model(model_cls)

return model_cls, arch
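
Putting the new branch together, the resolution flow for a sequence-classification architecture that vLLM does not support natively can be restated as a short reference sketch (simplified from the code above; error handling omitted):

from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import as_seq_cls_model

arch = "Qwen3ForSequenceClassification"          # not registered in vLLM
causal_arch = arch.replace("ForSequenceClassification", "ForCausalLM")

if causal_arch in ModelRegistry.get_supported_archs():
    # Resolve the causal-LM class, then wrap it so it exposes a
    # classification head, as the task == "classify" branch does above.
    model_cls, resolved_arch = ModelRegistry.resolve_model_cls([causal_arch])
    model_cls = as_seq_cls_model(model_cls)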
15 changes: 9 additions & 6 deletions vllm/model_executor/models/adapters.py
@@ -346,13 +346,13 @@ def load_weights_using_from_2_way_softmax(

false_id = tokenizer.convert_tokens_to_ids(tokens[0])
true_id = tokenizer.convert_tokens_to_ids(tokens[1])
weight = model.lm_head.weight.data[[true_id]].to(
score_weight = model.lm_head.weight.data[[true_id]].to(
torch.float32) - model.lm_head.weight.data[[false_id]].to(
torch.float32)

param = model.score.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, weight)
weight_loader(param, score_weight)

del model.lm_head
loaded_weights.add("score.weight")
@@ -365,15 +365,15 @@ def load_weights_no_post_processing(model,
torch.Tensor]]):
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader)
from vllm.model_executor.models.utils import AutoWeightsLoader

model_config = model.vllm_config.model_config
tokens = getattr(model.config, "classifier_from_token", [])
tokens = cast(list[int], tokens)
assert len(tokens) > 0

device = model.score.weight.device

if model.config.tie_word_embeddings:
model.lm_head = model.model.embed_tokens
else:
@@ -391,8 +391,11 @@ def load_weights_no_post_processing(model,
trust_remote_code=model_config.trust_remote_code)

token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
score_weight = model.lm_head.weight.data[token_ids].to(device)
model.score.weight.data.copy_(score_weight)
score_weight = model.lm_head.weight.data[token_ids]

param = model.score.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, score_weight)

del model.lm_head
loaded_weights.add("score.weight")
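
For reference, the score_weight assembled in load_weights_using_from_2_way_softmax is just the difference of the "yes" and "no" rows of the LM head; a toy check of why that single row reproduces the 2-way softmax score (illustrative only):

import torch

hidden = torch.randn(16)                             # toy final hidden state
w_false, w_true = torch.randn(16), torch.randn(16)   # lm_head rows for "no" / "yes"

# Probability of "yes" from the original 2-way softmax over token logits...
p_yes = torch.softmax(torch.stack([hidden @ w_false, hidden @ w_true]), dim=0)[1]
# ...equals a sigmoid over a single score row built from the row difference.
p_yes_single_row = torch.sigmoid(hidden @ (w_true - w_false))

assert torch.allclose(p_yes, p_yes_single_row, atol=1e-6)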