diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 6c4fde5fdfa..2bcafc0c21d 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -176,9 +176,12 @@ def mteb_test_embed_models(hf_runner, max_model_len=None, **vllm_extra_kwargs) as vllm_model: + model_config = vllm_model.model.llm_engine.model_config + if model_info.architecture: - assert (model_info.architecture - in vllm_model.model.llm_engine.model_config.architectures) + assert model_info.architecture in model_config.architectures + assert (model_config.model_info.default_pooling_type == + model_info.default_pooling_type) vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS) @@ -289,6 +292,8 @@ def mteb_test_rerank_models(hf_runner, if model_info.architecture: assert (model_info.architecture in model_config.architectures) assert model_config.hf_config.num_labels == 1 + assert (model_config.model_info.default_pooling_type == + model_info.default_pooling_type) vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model), tasks=MTEB_RERANK_TASKS, diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py index 64a8f25220d..6f44d180095 100644 --- a/tests/models/language/pooling/test_baai.py +++ b/tests/models/language/pooling/test_baai.py @@ -2,55 +2,56 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import EmbedModelInfo, RerankModelInfo +from ...utils import (CLSEmbedModelInfo, CLSRerankModelInfo, EmbedModelInfo, + RerankModelInfo) from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ ########## BertModel - EmbedModelInfo("BAAI/bge-base-en", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("BAAI/bge-base-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-en", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-en", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh-noinstruct", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-base-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-base-zh-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-zh-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh-v1.5", - architecture="BertModel", - enable_test=False), + CLSEmbedModelInfo("BAAI/bge-base-en", + architecture="BertModel", + enable_test=True), + CLSEmbedModelInfo("BAAI/bge-base-zh", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("BAAI/bge-small-en", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("BAAI/bge-small-zh", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("BAAI/bge-large-en", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("BAAI/bge-large-zh", + architecture="BertModel", + enable_test=False), + 
CLSEmbedModelInfo("BAAI/bge-large-zh-noinstruct", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("BAAI/bge-base-en-v1.5", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("BAAI/bge-base-zh-v1.5", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("BAAI/bge-small-en-v1.5", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("BAAI/bge-small-zh-v1.5", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("BAAI/bge-large-en-v1.5", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("BAAI/bge-large-zh-v1.5", + architecture="BertModel", + enable_test=False), ########## XLMRobertaModel - EmbedModelInfo("BAAI/bge-m3", - architecture="XLMRobertaModel", - enable_test=True), + CLSEmbedModelInfo("BAAI/bge-m3", + architecture="XLMRobertaModel", + enable_test=True), ########## Qwen2Model EmbedModelInfo("BAAI/bge-code-v1", architecture="Qwen2Model", @@ -60,15 +61,15 @@ RERANK_MODELS = [ ########## XLMRobertaForSequenceClassification - RerankModelInfo("BAAI/bge-reranker-base", - architecture="XLMRobertaForSequenceClassification", - enable_test=True), - RerankModelInfo("BAAI/bge-reranker-large", - architecture="XLMRobertaForSequenceClassification", - enable_test=False), - RerankModelInfo("BAAI/bge-reranker-v2-m3", - architecture="XLMRobertaForSequenceClassification", - enable_test=False) + CLSRerankModelInfo("BAAI/bge-reranker-base", + architecture="XLMRobertaForSequenceClassification", + enable_test=True), + CLSRerankModelInfo("BAAI/bge-reranker-large", + architecture="XLMRobertaForSequenceClassification", + enable_test=False), + CLSRerankModelInfo("BAAI/bge-reranker-v2-m3", + architecture="XLMRobertaForSequenceClassification", + enable_test=False) ] diff --git a/tests/models/language/pooling/test_classify_auto_prefix_cache_support.py b/tests/models/language/pooling/test_classify_auto_prefix_cache_support.py new file mode 100644 index 00000000000..16babc6dbb4 --- /dev/null +++ b/tests/models/language/pooling/test_classify_auto_prefix_cache_support.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: SIM117 +# Keep Decode-only SequenceClassification models support auto prefix cache +import pytest +import torch +from transformers import AutoModelForSequenceClassification + + +@pytest.mark.parametrize( + "model", + ["jason9693/Qwen2.5-1.5B-apeach"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_decode_only_classify( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + monkeypatch, +) -> None: + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + enable_prefix_caching=True) as vllm_model: + vllm_outputs = vllm_model.classify(example_prompts) + + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, + 1e-3 if dtype == "float" else 1e-2) + + +@pytest.mark.parametrize( + "model", + ["Alibaba-NLP/gte-Qwen2-1.5B-instruct"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_encode_only_classify( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + monkeypatch, +) -> None: + with pytest.raises(RuntimeError): + with vllm_runner(model, + 
max_model_len=512, + dtype=dtype, + enable_prefix_caching=True) as vllm_model: + vllm_model.classify(example_prompts) + # Is there any way to capture errors in worker processes? + # NotImplementedError: Encoder self-attention and encoder/decoder + # cross-attention are not implemented for FlashAttentionImpl diff --git a/tests/models/language/pooling/test_cross_encoder.py b/tests/models/language/pooling/test_cross_encoder.py index 9a33063d7b4..8ef5aedfae8 100644 --- a/tests/models/language/pooling/test_cross_encoder.py +++ b/tests/models/language/pooling/test_cross_encoder.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from .mteb_utils import RerankModelInfo, mteb_test_rerank_models +from ...utils import CLSRerankModelInfo, RerankModelInfo +from .mteb_utils import mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", - architecture="BertForSequenceClassification"), + CLSRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", + architecture="BertForSequenceClassification"), RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", architecture="Qwen3ForSequenceClassification") ] diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 0ad54785308..e1d7d6d109a 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -4,47 +4,48 @@ import pytest -from .embed_utils import EmbedModelInfo, correctness_test_embed_models +from ...utils import CLSEmbedModelInfo, EmbedModelInfo +from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ ########## BertModel - EmbedModelInfo("thenlper/gte-large", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("thenlper/gte-base", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-small", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-large-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-base-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-small-zh", - architecture="BertModel", - enable_test=False), + CLSEmbedModelInfo("thenlper/gte-large", + architecture="BertModel", + enable_test=True), + CLSEmbedModelInfo("thenlper/gte-base", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("thenlper/gte-small", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("thenlper/gte-large-zh", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("thenlper/gte-base-zh", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("thenlper/gte-small-zh", + architecture="BertModel", + enable_test=False), ########### NewModel - EmbedModelInfo("Alibaba-NLP/gte-multilingual-base", - architecture="GteNewModel", - enable_test=True), - EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", - architecture="GteNewModel", - enable_test=True), - EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", - architecture="GteNewModel", - enable_test=True), + CLSEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", + architecture="GteNewModel", + enable_test=True), + CLSEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", + architecture="GteNewModel", + enable_test=True), + CLSEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", + architecture="GteNewModel", + enable_test=True), ########### Qwen2ForCausalLM EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 
architecture="Qwen2ForCausalLM", enable_test=True), ########## ModernBertModel - EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", - architecture="ModernBertModel", - enable_test=True), + CLSEmbedModelInfo("Alibaba-NLP/gte-modernbert-base", + architecture="ModernBertModel", + enable_test=True), ########## Qwen3ForCausalLM EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", architecture="Qwen3ForCausalLM", diff --git a/tests/models/language/pooling/test_intfloat.py b/tests/models/language/pooling/test_intfloat.py index d899aaada26..ef0edcd6742 100644 --- a/tests/models/language/pooling/test_intfloat.py +++ b/tests/models/language/pooling/test_intfloat.py @@ -2,34 +2,34 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import EmbedModelInfo +from ...utils import CLSEmbedModelInfo, EmbedModelInfo from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ ########## BertModel - EmbedModelInfo("intfloat/e5-small", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("intfloat/e5-base", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("intfloat/e5-large", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("intfloat/multilingual-e5-small", - architecture="BertModel", - enable_test=False), + CLSEmbedModelInfo("intfloat/e5-small", + architecture="BertModel", + enable_test=True), + CLSEmbedModelInfo("intfloat/e5-base", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("intfloat/e5-large", + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("intfloat/multilingual-e5-small", + architecture="BertModel", + enable_test=False), ########## XLMRobertaModel - EmbedModelInfo("intfloat/multilingual-e5-base", - architecture="XLMRobertaModel", - enable_test=True), - EmbedModelInfo("intfloat/multilingual-e5-large", - architecture="XLMRobertaModel", - enable_test=False), - EmbedModelInfo("intfloat/multilingual-e5-large-instruct", - architecture="XLMRobertaModel", - enable_test=False), + CLSEmbedModelInfo("intfloat/multilingual-e5-base", + architecture="XLMRobertaModel", + enable_test=True), + CLSEmbedModelInfo("intfloat/multilingual-e5-large", + architecture="XLMRobertaModel", + enable_test=False), + CLSEmbedModelInfo("intfloat/multilingual-e5-large-instruct", + architecture="XLMRobertaModel", + enable_test=False), ] diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 9bfe7411e16..2a81cf066b8 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -6,26 +6,26 @@ from vllm import PoolingParams -from ...utils import EmbedModelInfo, RerankModelInfo +from ...utils import CLSEmbedModelInfo, CLSRerankModelInfo from .embed_utils import (check_embeddings_close, correctness_test_embed_models, matryoshka_fy) from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models EMBEDDING_MODELS = [ - EmbedModelInfo("jinaai/jina-embeddings-v3", - architecture="XLMRobertaModel", - is_matryoshka=True) + CLSEmbedModelInfo("jinaai/jina-embeddings-v3", + architecture="XLMRobertaModel", + is_matryoshka=True) ] RERANK_MODELS = [ - RerankModelInfo("jinaai/jina-reranker-v2-base-multilingual", - architecture="XLMRobertaForSequenceClassification") + CLSRerankModelInfo("jinaai/jina-reranker-v2-base-multilingual", + architecture="XLMRobertaForSequenceClassification") ] @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) def 
test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: + model_info: CLSEmbedModelInfo) -> None: def hf_model_callback(model): model.encode = partial(model.encode, task="text-matching") @@ -38,7 +38,7 @@ def hf_model_callback(model): @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, + model_info: CLSEmbedModelInfo, example_prompts) -> None: def hf_model_callback(model): @@ -53,7 +53,7 @@ def hf_model_callback(model): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: + model_info: CLSRerankModelInfo) -> None: mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py index e16ec239a33..8d0f54f9432 100644 --- a/tests/models/language/pooling/test_nomic.py +++ b/tests/models/language/pooling/test_nomic.py @@ -3,34 +3,35 @@ import pytest -from .embed_utils import EmbedModelInfo, correctness_test_embed_models +from ...utils import CLSEmbedModelInfo +from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ - EmbedModelInfo("nomic-ai/nomic-embed-text-v1", - architecture="NomicBertModel", - enable_test=True), - EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", - architecture="NomicBertModel", - enable_test=False), - EmbedModelInfo("nomic-ai/CodeRankEmbed", - architecture="NomicBertModel", - enable_test=False), - EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", - architecture="NomicBertModel", - enable_test=True) + CLSEmbedModelInfo("nomic-ai/nomic-embed-text-v1", + architecture="NomicBertModel", + enable_test=True), + CLSEmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", + architecture="NomicBertModel", + enable_test=False), + CLSEmbedModelInfo("nomic-ai/CodeRankEmbed", + architecture="NomicBertModel", + enable_test=False), + CLSEmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", + architecture="NomicBertModel", + enable_test=True) ] @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: + model_info: CLSEmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, + model_info: CLSEmbedModelInfo, example_prompts) -> None: correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py index d6b5dbd0837..71b358e8f91 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -3,54 +3,55 @@ import pytest -from .embed_utils import EmbedModelInfo, correctness_test_embed_models +from ...utils import CLSEmbedModelInfo +from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ - EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", - is_matryoshka=False, - architecture="BertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-s", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m", - is_matryoshka=False, - 
architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", - is_matryoshka=False, - architecture="NomicBertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-l", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", - is_matryoshka=True, - architecture="BertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", - is_matryoshka=True, - architecture="XLMRobertaModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", - is_matryoshka=True, - architecture="GteModel", - enable_test=True), + CLSEmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", + is_matryoshka=False, + architecture="BertModel", + enable_test=True), + CLSEmbedModelInfo("Snowflake/snowflake-arctic-embed-s", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("Snowflake/snowflake-arctic-embed-m", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", + is_matryoshka=False, + architecture="NomicBertModel", + enable_test=True), + CLSEmbedModelInfo("Snowflake/snowflake-arctic-embed-l", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + CLSEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + architecture="BertModel", + enable_test=True), + CLSEmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", + is_matryoshka=True, + architecture="XLMRobertaModel", + enable_test=True), + CLSEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", + is_matryoshka=True, + architecture="GteModel", + enable_test=True), ] @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: + model_info: CLSEmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, + model_info: CLSEmbedModelInfo, example_prompts) -> None: correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git a/tests/models/registry.py b/tests/models/registry.py index 9d3fc8a1b1c..a0e0dd362da 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -265,7 +265,6 @@ def check_available_online( "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), - "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), @@ -292,7 +291,6 @@ def check_available_online( # [Text-only] "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501 - "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True), @@ -311,7 +309,6 @@ def check_available_online( 
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), - "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501 "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 @@ -327,12 +324,6 @@ def check_available_online( _CROSS_ENCODER_EXAMPLE_MODELS = { # [Text-only] "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 - "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 - v0_only=True, - hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 - "classifier_from_token": ["Yes"], # noqa: E501 - "method": "no_post_processing"}), # noqa: E501 - "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501 "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 @@ -446,6 +437,19 @@ def check_available_online( "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 } +_AUTOMATIC_CONVERTED_MODELS = { + # Use as_seq_cls_model for automatic conversion + "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 + v0_only=True, + hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 + "classifier_from_token": ["Yes"], # noqa: E501 + "method": "no_post_processing"}), # noqa: E501 + "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 + "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501 + "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 + "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 +} + _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "EAGLEModel": _HfExamplesInfo("JackFram/llama-68m", speculative_model="abhigoyal/vllm-eagle-llama-68m-random"), # noqa: E501 @@ -513,4 +517,5 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo: raise ValueError(f"No example model defined for {model_id}") -HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) \ No newline at end of file +HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) +AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index ea6a2cc37cc..f142a3f6adb 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -13,10 +13,13 @@ from vllm.v1.engine.core import EngineCore as V1EngineCore from ..utils import create_new_process_for_each_test -from .registry import HF_EXAMPLE_MODELS +from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS -@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) 
+@pytest.mark.parametrize(
+    "model_arch",
+    HF_EXAMPLE_MODELS.get_supported_archs()
+    | AUTO_EXAMPLE_MODELS.get_supported_archs())
 @create_new_process_for_each_test()
 def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
     """The reason for using create_new_process_for_each_test is to avoid
diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index b7b99ce41cb..b87290e96a2 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -138,3 +138,38 @@ def test_quantization(
         name_0="transformers",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["jason9693/Qwen2.5-1.5B-apeach"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_classify(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    monkeypatch,
+) -> None:
+    import torch
+    from transformers import AutoModelForSequenceClassification
+
+    with vllm_runner(model,
+                     max_model_len=512,
+                     dtype=dtype,
+                     model_impl="transformers") as vllm_model:
+        vllm_outputs = vllm_model.classify(example_prompts)
+
+    with hf_runner(model,
+                   dtype=dtype,
+                   auto_cls=AutoModelForSequenceClassification) as hf_model:
+        hf_outputs = hf_model.classify(example_prompts)
+
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output)
+        vllm_output = torch.tensor(vllm_output)
+
+        assert torch.allclose(hf_output, vllm_output,
+                              1e-3 if dtype == "float" else 1e-2)
diff --git a/tests/models/utils.py b/tests/models/utils.py
index cdf8d02df73..495550d16a3 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -335,11 +335,21 @@ class EmbedModelInfo(NamedTuple):
     matryoshka_dimensions: Optional[list[int]] = None
     architecture: str = ""
     dtype: str = "auto"
+    default_pooling_type: str = "LAST"
     enable_test: bool = True
 
 
+class CLSEmbedModelInfo(EmbedModelInfo):
+    default_pooling_type: str = "CLS"
+
+
 class RerankModelInfo(NamedTuple):
     name: str
     architecture: str = ""
     dtype: str = "auto"
+    default_pooling_type: str = "LAST"
     enable_test: bool = True
+
+
+class CLSRerankModelInfo(RerankModelInfo):
+    default_pooling_type: str = "CLS"
diff --git a/tests/test_config.py b/tests/test_config.py
index 015baef9181..253d4d7e33e 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -292,6 +292,26 @@ def test_get_pooling_config_from_args():
     assert asdict(pooling_config) == asdict(override_pooler_config)
 
 
+@pytest.mark.parametrize(
+    ("model_id", "default_pooling_type"),
+    [
+        ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST"),  # LLM
+        ("BAAI/bge-base-en", "CLS")  # BertModel
+    ])
+def test_default_pooling_type(model_id, default_pooling_type):
+    model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+    assert model_config.model_info.default_pooling_type == default_pooling_type
+
+
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
 def test_get_bert_tokenization_sentence_transformer_config():
diff --git a/vllm/config.py b/vllm/config.py
index dc8acad25a3..e492c52fcf8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -551,7 +551,7 @@ def __post_init__(self) -> None:
         # For pooling models, self.task is used to indicate the
         # user-selected task
         if self.task == "score":
-            if self.registry.is_cross_encoder_model(self.architectures):
+            if self._is_classify_task(self.architectures):
                 self.task = "classify"
             else:
                 self.task = "embed"
@@ -772,14 +772,34 @@
def _init_pooler_config(self) -> Optional["PoolerConfig"]: if getattr(pooler_config, k) is None: setattr(pooler_config, k, v) - if self.is_matryoshka: - if pooler_config.normalize is None: + # set default pooler config + if pooler_config.pooling_type is None: + if self.task == "reward": + pooler_config.pooling_type = "ALL" + else: + default_pooling_type = self.model_info.default_pooling_type + pooler_config.pooling_type = default_pooling_type + if pooler_config.normalize is None: + if self.task in ["classify", "reward", "pooling"]: + pooler_config.normalize = False + elif self.task == "embed": pooler_config.normalize = True - elif not pooler_config.normalize: - raise ValueError( - "`normalize` must be enabled (set to True) " - "for models that are compatible with " - "Matryoshka Representation.") + else: + raise ValueError(f"Pooling runner does not " + f"support {self.task} task.") + if pooler_config.softmax is None: + if self.task == "classify": + pooler_config.softmax = True + elif self.task in ["embed", "reward", "pooling"]: + pooler_config.softmax = False + else: + raise ValueError(f"Pooling runner does not " + f"support {self.task} task.") + + if self.is_matryoshka and not pooler_config.normalize: + raise ValueError("`normalize` must be enabled (set to True) " + "for models that are compatible with " + "Matryoshka Representation.") return pooler_config @@ -806,6 +826,12 @@ def _verify_tokenizer_mode(self) -> None: f"one of {get_args(TokenizerMode)}.") self.tokenizer_mode = tokenizer_mode + def _is_classify_task(self, architectures: list[str]): + for arch in architectures: + if arch.endswith("ForSequenceClassification"): + return True + return self.registry.is_cross_encoder_model(architectures) + def _get_preferred_pooling_task( self, architectures: list[str], @@ -813,14 +839,11 @@ def _get_preferred_pooling_task( model_id = self.model if get_pooling_config(model_id, self.revision): return "embed" - if self.registry.is_cross_encoder_model(architectures): - return "classify" if self.registry.is_transcription_model(architectures): return "transcription" suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [ # Other models follow this pattern - ("ForSequenceClassification", "classify"), ("EmbeddingModel", "embed"), ("RewardModel", "reward"), ] @@ -878,11 +901,14 @@ def _get_supported_tasks( self, task_option: TaskOption, ) -> dict[RunnerType, list[_ResolvedTask]]: - return { - "generate": self._get_supported_generation_tasks(task_option), - "pooling": self._get_supported_pooling_tasks(task_option), - "draft": ["draft"] - } + if self._is_classify_task(self.architectures): + return {"generate": [], "pooling": ["classify"], "draft": []} + else: + return { + "generate": self._get_supported_generation_tasks(task_option), + "pooling": self._get_supported_pooling_tasks(task_option), + "draft": ["draft"] + } def _get_supported_runner_types( self, @@ -925,12 +951,16 @@ def _resolve_runner( f"Available tasks for runner={task_runner!r}: " f"{supported_tasks[task_runner]}") + if "classify" in supported_tasks.get("pooling", []): + # When multiple pooling tasks are present, default to + # pooling (eg cross-encoder) for non-standard architectures. 
+ return "pooling" + suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [ ("ForCausalLM", "generate"), ("ForConditionalGeneration", "generate"), ("ChatModel", "generate"), ("LMHeadModel", "generate"), - ("ForSequenceClassification", "pooling"), ("EmbeddingModel", "pooling"), ("RewardModel", "pooling"), ] @@ -940,10 +970,6 @@ def _resolve_runner( if arch.endswith(suffix) and pref_runner in supported_runner_types: return pref_runner - if "classify" in supported_tasks.get("pooling", []): - # When multiple pooling tasks are present, default to - # pooling (eg cross-encoder) for non-standard architectures. - return "pooling" if "generate" in supported_runner_types: return "generate" if "pooling" in supported_runner_types: @@ -1525,7 +1551,7 @@ def is_v1_compatible(self) -> bool: @property def is_matryoshka(self) -> bool: - return (hasattr(self.hf_config, "matryoshka_dimensions") + return (bool(getattr(self.hf_config, "matryoshka_dimensions", None)) or getattr(self.hf_config, "is_matryoshka", False)) @property @@ -1539,13 +1565,11 @@ def use_pad_token(self) -> bool: return getattr(self.hf_config, "use_pad_token", True) def get_and_verify_max_len(self, max_model_len: int): - # For pooling models, the tokenizer's `model_max_length` is often a - # reliable source for the maximum sequence length. However, for - # generative models, this can be incorrect and unduly limit the - # context window (e.g., DeepSeek-R1). Therefore, we only consider - # tokenizer_config for pooling models. + # Consider max_model_len in tokenizer_config only when + # pooling models use absolute position_embedding. tokenizer_config = None - if self.runner_type == "pooling": + if (self.runner_type == "pooling" and getattr( + self.hf_config, "position_embedding_type", "") == "absolute"): tokenizer_config = try_get_tokenizer_config( self.tokenizer, trust_remote_code=self.trust_remote_code, diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 8e5f332ba7c..01444c35156 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -22,7 +22,8 @@ QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.adapters import (as_embedding_model, - as_reward_model) + as_reward_model, + as_seq_cls_model) from vllm.model_executor.models.interfaces import SupportsQuant from vllm.utils import is_pin_memory_available @@ -238,9 +239,28 @@ def get_model_architecture( vllm_supported_archs = ModelRegistry.get_supported_archs() vllm_not_supported = not any(arch in vllm_supported_archs for arch in architectures) + + if vllm_not_supported: + # try automatic conversion in adapters.py + for arch in architectures: + if not arch.endswith("ForSequenceClassification"): + continue + + assert model_config.task == "classify" + causal_lm_arch = arch.replace("ForSequenceClassification", + "ForCausalLM") + causal_lm_arch_vllm_supported = (causal_lm_arch + in vllm_supported_archs) + + if causal_lm_arch_vllm_supported: + architectures = [causal_lm_arch] + vllm_not_supported = False + break + if (model_config.model_impl == ModelImpl.TRANSFORMERS or model_config.model_impl != ModelImpl.VLLM and vllm_not_supported): architectures = resolve_transformers_arch(model_config, architectures) + logger.debug_once("Resolve transformers arch %s", str(architectures)) elif (model_config.quantization is not None and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): @@ 
-248,12 +268,13 @@ def get_model_architecture( model_cls, arch = ModelRegistry.resolve_model_cls(architectures) if model_config.task == "embed": + logger.debug_once("Automatic conversion using `as_embedding_model`.") model_cls = as_embedding_model(model_cls) elif model_config.task == "classify": - # Cannot automatically run as_seq_cls_model, - # otherwise it will cause a circular reference on is_cross_encoder_model - pass + logger.debug_once("Automatic conversion using `as_seq_cls_model`.") + model_cls = as_seq_cls_model(model_cls) elif model_config.task == "reward": + logger.debug_once("Automatic conversion using `as_reward_model`.") model_cls = as_reward_model(model_cls) return model_cls, arch diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index dcdf69f773a..fcc8b2ca275 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -193,6 +193,7 @@ def __init__( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + score_bias: bool = getattr(config, 'score_bias', False) self.vllm_config = vllm_config self.task = vllm_config.model_config.task @@ -203,7 +204,7 @@ def __init__( config.num_labels, quant_config=quant_config, input_is_parallel=False, - bias=False, + bias=score_bias, prefix=maybe_prefix( prefix, "score")) @@ -349,13 +350,13 @@ def load_weights_using_from_2_way_softmax( false_id = tokenizer.convert_tokens_to_ids(tokens[0]) true_id = tokenizer.convert_tokens_to_ids(tokens[1]) - weight = model.lm_head.weight.data[[true_id]].to( + score_weight = model.lm_head.weight.data[[true_id]].to( torch.float32) - model.lm_head.weight.data[[false_id]].to( torch.float32) param = model.score.weight weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, weight) + weight_loader(param, score_weight) del model.lm_head loaded_weights.add("score.weight") @@ -368,6 +369,8 @@ def load_weights_no_post_processing(model, torch.Tensor]]): from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead) + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader) from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config @@ -375,8 +378,6 @@ def load_weights_no_post_processing(model, tokens = cast(list[int], tokens) assert len(tokens) > 0 - device = model.score.weight.device - if model.config.tie_word_embeddings: model.lm_head = model.model.embed_tokens else: @@ -394,8 +395,11 @@ def load_weights_no_post_processing(model, trust_remote_code=model_config.trust_remote_code) token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] - score_weight = model.lm_head.weight.data[token_ids].to(device) - model.score.weight.data.copy_(score_weight) + score_weight = model.lm_head.weight.data[token_ids] + + param = model.score.weight + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, score_weight) del model.lm_head loaded_weights.add("score.weight") diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index a43803ed433..522daaa8c50 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -25,7 +25,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput -from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only +from .interfaces import (ClsPooling, SupportsCrossEncoding, 
SupportsQuant,
+                         SupportsV0Only)
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -314,7 +315,7 @@ def forward(self, hidden_states: torch.Tensor,
         return hidden_states
 
 
-class BertModel(nn.Module, SupportsQuant):
+class BertModel(nn.Module, SupportsQuant, ClsPooling):
     packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
 
     def __init__(self,
@@ -388,7 +389,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
+class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant, ClsPooling):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
@@ -451,7 +452,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
 
 
 class BertForSequenceClassification(nn.Module, SupportsV0Only,
-                                    SupportsCrossEncoding, SupportsQuant):
+                                    SupportsCrossEncoding, SupportsQuant,
+                                    ClsPooling):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index 0b7350f07d3..cd6f1aae2d5 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -24,7 +24,7 @@
                                                VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsV0Only
-from vllm.model_executor.models.interfaces import SupportsQuant
+from vllm.model_executor.models.interfaces import ClsPooling, SupportsQuant
 from vllm.model_executor.models.utils import WeightsMapper
 from vllm.sequence import IntermediateTensors
 
@@ -398,7 +398,7 @@ def forward(
         return hidden_states
 
 
-class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
+class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant, ClsPooling):
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index bc8179f886f..59c3102add4 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -43,7 +43,6 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .adapters import as_seq_cls_model
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -426,6 +425,3 @@ def load_weights(self, weights: Iterable[tuple[str,
                 if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
-
-
-GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM)
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 3a97641aa2f..876c7ab89d7 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -642,6 +642,13 @@ def has_step_pooler(model: Union[type[object], object]) -> bool:
         type(module).__name__ == "StepPool" for module in model.modules())
 
 
+class ClsPooling(Protocol):
+    """The interface required for all models
+    that use CLS as the default pooling type."""
+
+    default_pooling_type: ClassVar[str] = "CLS"
+
+
 class SupportsQuant:
     """The interface required for all models that support quantization."""
 
diff --git
a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2434ac9d205..48ec611df12 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -49,7 +49,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, @@ -646,6 +645,3 @@ def permute(w: torch.Tensor, n_heads: int): name = name.replace(item, mapping[item]) return name, loaded_weight - - -LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 9d619b38d38..c3a12cc4f56 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -21,7 +21,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput -from .interfaces import SupportsCrossEncoding, SupportsV0Only +from .interfaces import ClsPooling, SupportsCrossEncoding, SupportsV0Only from .utils import WeightsMapper, maybe_prefix @@ -196,7 +196,7 @@ def forward( @support_torch_compile -class ModernBertModel(nn.Module): +class ModernBertModel(nn.Module, ClsPooling): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"layers.": "encoder_layer.layers."}) @@ -278,7 +278,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, - SupportsCrossEncoding): + SupportsCrossEncoding, ClsPooling): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7ef9d248da4..23f65b99c22 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -50,7 +50,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, @@ -496,6 +495,3 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index de99a76f289..393ce41a91a 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -44,7 +44,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2Model @@ -320,6 +319,3 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 79190860ac9..7313bd1e49a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -138,7 +138,6 @@ 
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), - "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"), "GritLM": ("gritlm", "GritLM"), "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"), "GteNewModel": ("bert_with_rope", "GteNewModel"), @@ -181,10 +180,6 @@ "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"), # [Auto-converted (see adapters.py)] - "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501 - "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501 - "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 - "LlamaForSequenceClassification": ("llama", "LlamaForSequenceClassification"), # noqa: E501 "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, } @@ -273,11 +268,12 @@ ] -@dataclass(frozen=True) +@dataclass() class _ModelInfo: architecture: str is_text_generation_model: bool is_pooling_model: bool + default_pooling_type: str supports_cross_encoding: bool supports_multimodal: bool supports_pp: bool @@ -295,6 +291,8 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": architecture=model.__name__, is_text_generation_model=is_text_generation_model(model), is_pooling_model=True, # Can convert any model into a pooling model + default_pooling_type=getattr(model, "default_pooling_type", + "LAST"), supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), @@ -459,10 +457,19 @@ def _try_load_model_cls(self, return _try_load_model_cls(model_arch, self.models[model_arch]) def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: - if model_arch not in self.models: - return None + if model_arch in self.models: + return _try_inspect_model_cls(model_arch, self.models[model_arch]) - return _try_inspect_model_cls(model_arch, self.models[model_arch]) + if model_arch.endswith("ForSequenceClassification"): + arch = model_arch.replace("ForSequenceClassification", + "ForCausalLM") + if arch not in self.models: + return None + info = _try_inspect_model_cls(model_arch, self.models[arch]) + info.supports_cross_encoding = True + return info + + return None def _normalize_archs( self, @@ -477,6 +484,15 @@ def _normalize_archs( normalized_arch = list( filter(lambda model: model in self.models, architectures)) + # try automatic conversion in adapters.py + for arch in architectures: + if not arch.endswith("ForSequenceClassification"): + continue + causal_lm_arch = arch.replace("ForSequenceClassification", + "ForCausalLM") + if causal_lm_arch in self.models: + normalized_arch.append(arch) + # make sure Transformers backend is put at the last as a fallback if len(normalized_arch) != len(architectures): normalized_arch.append("TransformersForCausalLM") diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 1d3a23a5e54..7fbdbfcb588 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -19,7 +19,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, SupportsV0Only +from .interfaces import ClsPooling, SupportsCrossEncoding, SupportsV0Only class RobertaEmbedding(nn.Module): @@ -154,7 +154,7 @@ def 
load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, - SupportsV0Only): + SupportsV0Only, ClsPooling): """A model that uses Roberta to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for @@ -164,7 +164,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, roberta: An instance of BertModel used for forward operations. _pooler: An instance of Pooler used for pooling operations. """ - jina_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ 'emb_ln': "embeddings.LayerNorm",