From 9a6125d69de4b3172b4cbb457e90d2391760922a Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 17 Jul 2025 15:55:45 +0800
Subject: [PATCH 1/6] + implicit conversion

Signed-off-by: wang.yuqi
---
 tests/models/registry.py                  | 23 ++++++----
 tests/models/test_initialization.py       |  7 ++-
 tests/models/test_transformers.py         | 35 ++++++++++++++
 vllm/config.py                            | 46 ++++++++++---------
 vllm/model_executor/model_loader/utils.py | 29 ++++++++++--
 vllm/model_executor/models/adapters.py    | 15 +++---
 vllm/model_executor/models/gemma.py       |  4 --
 vllm/model_executor/models/gpt2.py        | 56 +----------------------
 vllm/model_executor/models/llama.py       |  4 --
 vllm/model_executor/models/qwen2.py       |  4 --
 vllm/model_executor/models/qwen3.py       |  4 --
 vllm/model_executor/models/registry.py    | 31 +++++++++----
 12 files changed, 136 insertions(+), 122 deletions(-)

diff --git a/tests/models/registry.py b/tests/models/registry.py
index d2e70e291df3..1c1a8c3c751d 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -265,7 +265,6 @@ def check_available_online(
     "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
-    "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
     "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
     "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
     "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
@@ -292,7 +291,6 @@ def check_available_online(
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
-    "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
@@ -311,7 +309,6 @@ def check_available_online(
     "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
     "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
     "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
-    "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
     "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
     "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
     "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
@@ -327,12 +324,6 @@ def check_available_online(
 _CROSS_ENCODER_EXAMPLE_MODELS = {
     # [Text-only]
     "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
-    "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
-                                                      v0_only=True,
-                                                      hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
-                                                                    "classifier_from_token": ["Yes"], # noqa: E501
-                                                                    "method": "no_post_processing"}), # noqa: E501
-    "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
     "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
     "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base",
v0_only=True), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 @@ -447,6 +438,19 @@ def check_available_online( "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 } +_AUTOMATIC_CONVERTED_MODELS = { + # Use as_seq_cls_model for automatic conversion + "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 + v0_only=True, + hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 + "classifier_from_token": ["Yes"], # noqa: E501 + "method": "no_post_processing"}), # noqa: E501 + "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 + "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501 + "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 + "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 +} + _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "EAGLEModel": _HfExamplesInfo("JackFram/llama-68m", speculative_model="abhigoyal/vllm-eagle-llama-68m-random"), # noqa: E501 @@ -520,3 +524,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo: HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) +AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 52005e74ef7e..eda36653614c 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -13,10 +13,13 @@ from vllm.v1.engine.core import EngineCore as V1EngineCore from ..utils import create_new_process_for_each_test -from .registry import HF_EXAMPLE_MODELS +from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS -@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) +@pytest.mark.parametrize( + "model_arch", + HF_EXAMPLE_MODELS.get_supported_archs() + & AUTO_EXAMPLE_MODELS.get_supported_archs()) @create_new_process_for_each_test() def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): """The reason for using create_new_process_for_each_test is to avoid diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index b7b99ce41cbb..b87290e96a27 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -138,3 +138,38 @@ def test_quantization( name_0="transformers", name_1="vllm", ) + + +@pytest.mark.parametrize( + "model", + ["jason9693/Qwen2.5-1.5B-apeach"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_classify( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + monkeypatch, +) -> None: + import torch + from transformers import AutoModelForSequenceClassification + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + model_impl="transformers") as vllm_model: + vllm_outputs = vllm_model.classify(example_prompts) + + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, + 1e-3 if dtype == "float" else 1e-2) diff --git a/vllm/config.py b/vllm/config.py index 22f740171369..54848e7a3a23 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ 
-551,7 +551,7 @@ def __post_init__(self) -> None: # For pooling models, self.task is used to indicate the # user-selected task if self.task == "score": - if self.registry.is_cross_encoder_model(self.architectures): + if self._is_classify_task(self.architectures): self.task = "classify" else: self.task = "embed" @@ -806,6 +806,12 @@ def _verify_tokenizer_mode(self) -> None: f"one of {get_args(TokenizerMode)}.") self.tokenizer_mode = tokenizer_mode + def _is_classify_task(self, architectures: list[str]): + for arch in architectures: + if arch.endswith("ForSequenceClassification"): + return True + return self.registry.is_cross_encoder_model(architectures) + def _get_preferred_pooling_task( self, architectures: list[str], @@ -813,14 +819,11 @@ def _get_preferred_pooling_task( model_id = self.model if get_pooling_config(model_id, self.revision): return "embed" - if self.registry.is_cross_encoder_model(architectures): - return "classify" if self.registry.is_transcription_model(architectures): return "transcription" suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [ # Other models follow this pattern - ("ForSequenceClassification", "classify"), ("EmbeddingModel", "embed"), ("RewardModel", "reward"), ] @@ -878,11 +881,14 @@ def _get_supported_tasks( self, task_option: TaskOption, ) -> dict[RunnerType, list[_ResolvedTask]]: - return { - "generate": self._get_supported_generation_tasks(task_option), - "pooling": self._get_supported_pooling_tasks(task_option), - "draft": ["draft"] - } + if self._is_classify_task(self.architectures): + return {"generate": [], "pooling": ["classify"], "draft": []} + else: + return { + "generate": self._get_supported_generation_tasks(task_option), + "pooling": self._get_supported_pooling_tasks(task_option), + "draft": ["draft"] + } def _get_supported_runner_types( self, @@ -925,12 +931,16 @@ def _resolve_runner( f"Available tasks for runner={task_runner!r}: " f"{supported_tasks[task_runner]}") + if "classify" in supported_tasks.get("pooling", []): + # When multiple pooling tasks are present, default to + # pooling (eg cross-encoder) for non-standard architectures. + return "pooling" + suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [ ("ForCausalLM", "generate"), ("ForConditionalGeneration", "generate"), ("ChatModel", "generate"), ("LMHeadModel", "generate"), - ("ForSequenceClassification", "pooling"), ("EmbeddingModel", "pooling"), ("RewardModel", "pooling"), ] @@ -940,10 +950,6 @@ def _resolve_runner( if arch.endswith(suffix) and pref_runner in supported_runner_types: return pref_runner - if "classify" in supported_tasks.get("pooling", []): - # When multiple pooling tasks are present, default to - # pooling (eg cross-encoder) for non-standard architectures. - return "pooling" if "generate" in supported_runner_types: return "generate" if "pooling" in supported_runner_types: @@ -1525,7 +1531,7 @@ def is_v1_compatible(self) -> bool: @property def is_matryoshka(self) -> bool: - return (hasattr(self.hf_config, "matryoshka_dimensions") + return (bool(getattr(self.hf_config, "matryoshka_dimensions", None)) or getattr(self.hf_config, "is_matryoshka", False)) @property @@ -1539,13 +1545,11 @@ def use_pad_token(self) -> bool: return getattr(self.hf_config, "use_pad_token", True) def get_and_verify_max_len(self, max_model_len: int): - # For pooling models, the tokenizer's `model_max_length` is often a - # reliable source for the maximum sequence length. 
However, for - # generative models, this can be incorrect and unduly limit the - # context window (e.g., DeepSeek-R1). Therefore, we only consider - # tokenizer_config for pooling models. + # Consider max_model_len in tokenizer_config only when + # pooling models use absolute position_embedding. tokenizer_config = None - if self.runner_type == "pooling": + if (self.runner_type == "pooling" and getattr( + self.hf_config, "position_embedding_type", "") == "absolute"): tokenizer_config = try_get_tokenizer_config( self.tokenizer, trust_remote_code=self.trust_remote_code, diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 8e5f332ba7cc..01444c35156e 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -22,7 +22,8 @@ QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.adapters import (as_embedding_model, - as_reward_model) + as_reward_model, + as_seq_cls_model) from vllm.model_executor.models.interfaces import SupportsQuant from vllm.utils import is_pin_memory_available @@ -238,9 +239,28 @@ def get_model_architecture( vllm_supported_archs = ModelRegistry.get_supported_archs() vllm_not_supported = not any(arch in vllm_supported_archs for arch in architectures) + + if vllm_not_supported: + # try automatic conversion in adapters.py + for arch in architectures: + if not arch.endswith("ForSequenceClassification"): + continue + + assert model_config.task == "classify" + causal_lm_arch = arch.replace("ForSequenceClassification", + "ForCausalLM") + causal_lm_arch_vllm_supported = (causal_lm_arch + in vllm_supported_archs) + + if causal_lm_arch_vllm_supported: + architectures = [causal_lm_arch] + vllm_not_supported = False + break + if (model_config.model_impl == ModelImpl.TRANSFORMERS or model_config.model_impl != ModelImpl.VLLM and vllm_not_supported): architectures = resolve_transformers_arch(model_config, architectures) + logger.debug_once("Resolve transformers arch %s", str(architectures)) elif (model_config.quantization is not None and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): @@ -248,12 +268,13 @@ def get_model_architecture( model_cls, arch = ModelRegistry.resolve_model_cls(architectures) if model_config.task == "embed": + logger.debug_once("Automatic conversion using `as_embedding_model`.") model_cls = as_embedding_model(model_cls) elif model_config.task == "classify": - # Cannot automatically run as_seq_cls_model, - # otherwise it will cause a circular reference on is_cross_encoder_model - pass + logger.debug_once("Automatic conversion using `as_seq_cls_model`.") + model_cls = as_seq_cls_model(model_cls) elif model_config.task == "reward": + logger.debug_once("Automatic conversion using `as_reward_model`.") model_cls = as_reward_model(model_cls) return model_cls, arch diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 5c09ac306052..862431dd2a7e 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -346,13 +346,13 @@ def load_weights_using_from_2_way_softmax( false_id = tokenizer.convert_tokens_to_ids(tokens[0]) true_id = tokenizer.convert_tokens_to_ids(tokens[1]) - weight = model.lm_head.weight.data[[true_id]].to( + score_weight = model.lm_head.weight.data[[true_id]].to( torch.float32) - model.lm_head.weight.data[[false_id]].to( torch.float32) param = model.score.weight 
weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, weight) + weight_loader(param, score_weight) del model.lm_head loaded_weights.add("score.weight") @@ -365,6 +365,8 @@ def load_weights_no_post_processing(model, torch.Tensor]]): from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead) + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader) from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config @@ -372,8 +374,6 @@ def load_weights_no_post_processing(model, tokens = cast(list[int], tokens) assert len(tokens) > 0 - device = model.score.weight.device - if model.config.tie_word_embeddings: model.lm_head = model.model.embed_tokens else: @@ -391,8 +391,11 @@ def load_weights_no_post_processing(model, trust_remote_code=model_config.trust_remote_code) token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] - score_weight = model.lm_head.weight.data[token_ids].to(device) - model.score.weight.data.copy_(score_weight) + score_weight = model.lm_head.weight.data[token_ids] + + param = model.score.weight + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, score_weight) del model.lm_head loaded_weights.add("score.weight") diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index bc8179f886fd..59c3102add4c 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -43,7 +43,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -426,6 +425,3 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 27021550f998..fd3decbaebec 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -40,11 +40,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors -from ..layers.pooler import Pooler, PoolingType from .interfaces import SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -320,58 +318,6 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights) -class GPT2ForSequenceClassification(nn.Module): - """GPT2 Model for sequence classification. - - This class expands GPT2Model with pooling and score functions - last token - is being used for classification. - - Attributes: - transformer: An instance of GPT2Model used for forward operations. - score: A layer for calculating logits. - _pooler: An instance of Pooler used for pooling operations. 
- """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - self.transformer = GPT2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "gpt2")) - self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) - pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=False, - softmax=True) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - hidden_states = self.transformer( - input_ids=input_ids, - position_ids=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) - logits = self.score(hidden_states) - return logits - - def _add_transformer_prefix( weights: Iterable[tuple[str, torch.Tensor]] ) -> Iterable[tuple[str, torch.Tensor]]: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2434ac9d205d..48ec611df12d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -49,7 +49,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, @@ -646,6 +645,3 @@ def permute(w: torch.Tensor, n_heads: int): name = name.replace(item, mapping[item]) return name, loaded_weight - - -LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7ef9d248da4b..23f65b99c22c 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -50,7 +50,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, @@ -496,6 +495,3 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index de99a76f2897..393ce41a91a0 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -44,7 +44,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2Model @@ -320,6 +319,3 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - 
-Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index bc936500bdc8..ec6e5af7891e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -138,7 +138,6 @@ "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), - "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"), "GritLM": ("gritlm", "GritLM"), "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"), "GteNewModel": ("bert_with_rope", "GteNewModel"), @@ -181,10 +180,6 @@ "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"), # [Auto-converted (see adapters.py)] - "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501 - "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501 - "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 - "LlamaForSequenceClassification": ("llama", "LlamaForSequenceClassification"), # noqa: E501 "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, } @@ -275,7 +270,7 @@ ] -@dataclass(frozen=True) +@dataclass() class _ModelInfo: architecture: str is_text_generation_model: bool @@ -461,10 +456,19 @@ def _try_load_model_cls(self, return _try_load_model_cls(model_arch, self.models[model_arch]) def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: - if model_arch not in self.models: - return None + if model_arch in self.models: + return _try_inspect_model_cls(model_arch, self.models[model_arch]) - return _try_inspect_model_cls(model_arch, self.models[model_arch]) + if model_arch.endswith("ForSequenceClassification"): + arch = model_arch.replace("ForSequenceClassification", + "ForCausalLM") + if arch not in self.models: + return None + info = _try_inspect_model_cls(model_arch, self.models[arch]) + info.supports_cross_encoding = True + return info + + return None def _normalize_archs( self, @@ -479,6 +483,15 @@ def _normalize_archs( normalized_arch = list( filter(lambda model: model in self.models, architectures)) + # try automatic conversion in adapters.py + for arch in architectures: + if not arch.endswith("ForSequenceClassification"): + continue + causal_lm_arch = arch.replace("ForSequenceClassification", + "ForCausalLM") + if causal_lm_arch in self.models: + normalized_arch.append(arch) + # make sure Transformers backend is put at the last as a fallback if len(normalized_arch) != len(architectures): normalized_arch.append("TransformersForCausalLM") From 043cdfe98ba1c487356532153135378a7e949303 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 17 Jul 2025 16:10:45 +0800 Subject: [PATCH 2/6] + test_implicit_converted_models Signed-off-by: wang.yuqi --- tests/models/test_initialization.py | 32 +++++++++++++++++++---------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index eda36653614c..14d243012b2f 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -13,23 +13,21 @@ from vllm.v1.engine.core import EngineCore as V1EngineCore from ..utils import create_new_process_for_each_test -from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS +from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels 
-@pytest.mark.parametrize( - "model_arch", - HF_EXAMPLE_MODELS.get_supported_archs() - & AUTO_EXAMPLE_MODELS.get_supported_archs()) @create_new_process_for_each_test() -def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): - """The reason for using create_new_process_for_each_test is to avoid - the WARNING: - "We must use the 'spawn' multiprocessing start method. Overriding +def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, + EXAMPLE_MODELS: HfExampleModels): + """The reason for using create_new_process_for_each_test is to avoid + the WARNING: + "We must use the 'spawn' multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'." - The spawn process causes the _initialize_kv_caches_v1 function below to + The spawn process causes the _initialize_kv_caches_v1 function below to become ineffective. """ - model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + + model_info = EXAMPLE_MODELS.get_hf_info(model_arch) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") @@ -130,3 +128,15 @@ def _initialize_kv_caches_v1(self, vllm_config): load_format="dummy", hf_overrides=hf_overrides, ) + + +@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) +def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): + can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) + + +@pytest.mark.parametrize("model_arch", + AUTO_EXAMPLE_MODELS.get_supported_archs()) +def test_implicit_converted_models(model_arch: str, + monkeypatch: pytest.MonkeyPatch): + can_initialize(model_arch, monkeypatch, AUTO_EXAMPLE_MODELS) From 14f64909161ce9c97147d2fb60c2c9c51fe6fa68 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 17 Jul 2025 16:25:20 +0800 Subject: [PATCH 3/6] fix Signed-off-by: wang.yuqi --- vllm/model_executor/models/registry.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ec6e5af7891e..f454ea7dab75 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -460,11 +460,19 @@ def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: return _try_inspect_model_cls(model_arch, self.models[model_arch]) if model_arch.endswith("ForSequenceClassification"): - arch = model_arch.replace("ForSequenceClassification", - "ForCausalLM") - if arch not in self.models: + causal_lm_arch = model_arch.replace("ForSequenceClassification", + "ForCausalLM") + if causal_lm_arch not in self.models: return None - info = _try_inspect_model_cls(model_arch, self.models[arch]) + + info = _try_inspect_model_cls(causal_lm_arch, + self.models[causal_lm_arch]) + + # Create a copy to avoid mutating the cached object + import copy + info = copy.copy(info) + + info.architecture = model_arch info.supports_cross_encoding = True return info From 04f9256d3e34d3dbd24f2c4a8522538c2464708b Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 17 Jul 2025 16:55:18 +0800 Subject: [PATCH 4/6] - gpt2 Signed-off-by: wang.yuqi --- tests/models/registry.py | 2 +- vllm/model_executor/models/gpt2.py | 56 +++++++++++++++++++++++++- vllm/model_executor/models/registry.py | 1 + 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 1c1a8c3c751d..35d43324a557 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -291,6 +291,7 @@ def 
check_available_online( # [Text-only] "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501 + "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True), @@ -445,7 +446,6 @@ def check_available_online( hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 "classifier_from_token": ["Yes"], # noqa: E501 "method": "no_post_processing"}), # noqa: E501 - "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501 "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index fd3decbaebec..27021550f998 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -40,9 +40,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, PoolerOutput +from ..layers.pooler import Pooler, PoolingType from .interfaces import SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -318,6 +320,58 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights) +class GPT2ForSequenceClassification(nn.Module): + """GPT2 Model for sequence classification. + + This class expands GPT2Model with pooling and score functions - last token + is being used for classification. + + Attributes: + transformer: An instance of GPT2Model used for forward operations. + score: A layer for calculating logits. + _pooler: An instance of Pooler used for pooling operations. 
+ """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.transformer = GPT2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "gpt2")) + self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) + pooler_config = vllm_config.model_config.pooler_config + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=False, + softmax=True) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.transformer( + input_ids=input_ids, + position_ids=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + logits = self.score(hidden_states) + return logits + + def _add_transformer_prefix( weights: Iterable[tuple[str, torch.Tensor]] ) -> Iterable[tuple[str, torch.Tensor]]: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f454ea7dab75..4a91948ecaed 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -138,6 +138,7 @@ "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), + "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"), "GritLM": ("gritlm", "GritLM"), "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"), "GteNewModel": ("bert_with_rope", "GteNewModel"), From 63166f2bffc426504682b76705474e13824a13f8 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 17 Jul 2025 17:03:45 +0800 Subject: [PATCH 5/6] + _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS Signed-off-by: wang.yuqi --- tests/models/registry.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 35d43324a557..85ca73ff3c6c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -291,7 +291,6 @@ def check_available_online( # [Text-only] "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501 - "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True), @@ -322,14 +321,29 @@ def check_available_online( is_available_online=False), # noqa: E501 } -_CROSS_ENCODER_EXAMPLE_MODELS = { - # [Text-only] +_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { + # [Decoder-only] + "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 + + # [Cross-encoder] "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 } +_AUTOMATIC_CONVERTED_MODELS = { + # Use as_seq_cls_model for automatic conversion + "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 + v0_only=True, + hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 + "classifier_from_token": ["Yes"], # noqa: E501 + "method": "no_post_processing"}), # noqa: E501 + "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501 + "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 + "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 +} + _MULTIMODAL_EXAMPLE_MODELS = { # [Decoder-only] "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), @@ -439,17 +453,6 @@ def check_available_online( "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 } -_AUTOMATIC_CONVERTED_MODELS = { - # Use as_seq_cls_model for automatic conversion - "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 - v0_only=True, - hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 - "classifier_from_token": ["Yes"], # noqa: E501 - "method": "no_post_processing"}), # noqa: E501 - "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501 - "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 - "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 -} _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "EAGLEModel": _HfExamplesInfo("JackFram/llama-68m", @@ -491,7 +494,7 @@ def check_available_online( _EXAMPLE_MODELS = { **_TEXT_GENERATION_EXAMPLE_MODELS, **_EMBEDDING_EXAMPLE_MODELS, - **_CROSS_ENCODER_EXAMPLE_MODELS, + **_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS, **_MULTIMODAL_EXAMPLE_MODELS, **_SPECULATIVE_DECODING_EXAMPLE_MODELS, **_TRANSFORMERS_MODELS, From e895aa9ab6b643400c512461ab9960ed86855169 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 18 Jul 2025 11:08:50 +0800 Subject: [PATCH 6/6] update Signed-off-by: wang.yuqi --- vllm/model_executor/model_loader/utils.py | 9 +++++---- vllm/model_executor/models/registry.py | 15 +++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 01444c35156e..190d1f006bc4 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -251,11 +251,12 @@ def get_model_architecture( "ForCausalLM") causal_lm_arch_vllm_supported = (causal_lm_arch in vllm_supported_archs) + if not causal_lm_arch_vllm_supported: + continue - if causal_lm_arch_vllm_supported: - architectures = [causal_lm_arch] - vllm_not_supported = False - break + architectures = [causal_lm_arch] + vllm_not_supported = False + break if (model_config.model_impl == ModelImpl.TRANSFORMERS or model_config.model_impl != ModelImpl.VLLM and vllm_not_supported): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b620aa21fed0..fd831727ab2f 100644 --- a/vllm/model_executor/models/registry.py +++ 
b/vllm/model_executor/models/registry.py
@@ -12,7 +12,7 @@
 import tempfile
 from abc import ABC, abstractmethod
 from collections.abc import Set
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from functools import lru_cache
 from typing import Callable, Optional, TypeVar, Union
 
@@ -272,7 +272,7 @@
 ]
 
 
-@dataclass()
+@dataclass(frozen=True)
 class _ModelInfo:
     architecture: str
     is_text_generation_model: bool
@@ -470,12 +470,11 @@ def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
 
             info = _try_inspect_model_cls(causal_lm_arch,
                                           self.models[causal_lm_arch])
-            # Create a copy to avoid mutating the cached object
-            import copy
-            info = copy.copy(info)
-
-            info.architecture = model_arch
-            info.supports_cross_encoding = True
+            info = _ModelInfo(**dict(
+                asdict(info), **{
+                    "architecture": model_arch,
+                    "supports_cross_encoding": True
+                }))
             return info
 
         return None
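
Usage sketch (not part of the patches above): a minimal example of how the implicit conversion in this series is exercised from the offline API, assuming the jason9693/Qwen2.5-1.5B-apeach checkpoint used in the tests and the current LLM.classify interface. The checkpoint declares Qwen2ForSequenceClassification, which no longer has a dedicated vLLM class after this series, so the registry and model loader fall back to Qwen2ForCausalLM wrapped by as_seq_cls_model.

    from vllm import LLM

    # The HF config declares Qwen2ForSequenceClassification; with this series,
    # get_model_architecture resolves it to the Qwen2ForCausalLM backbone and
    # applies as_seq_cls_model automatically for task="classify".
    llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")

    outputs = llm.classify(["vLLM is an easy-to-use library for LLM inference."])
    print(outputs[0].outputs.probs)  # per-class probabilities from the score head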