Commit 5111642

Authored by Isotr0py
[Doc] Update V1 status for decoder-only embedding models (#19952)
Signed-off-by: Isotr0py <2037008807@qq.com>
1 parent 1bcd15e commit 5111642

2 files changed (+18, -27 lines)

docs/models/supported_models.md

Lines changed: 10 additions & 9 deletions
@@ -407,15 +407,15 @@ Specified using `--task embed`.
 | Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
 |--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|----------------------|---------------------------|-----------------------|
 | `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | |
-| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | |
+| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
 | `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. || | |
 | `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. ||| |
 | `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. ||| |
 | `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. ||| |
-| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | |
-| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | |
-| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | |
+| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | |

 !!! note
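
As an aside (not part of this diff), a minimal offline sketch of exercising one of the embedding models now marked V1-capable, assuming vLLM's pooling API (`LLM(task="embed")` and `LLM.embed`); the model and the output attribute names are illustrative picks, not something this commit adds:

```python
from vllm import LLM

# Qwen3-Embedding is one of the decoder-only embedding models the table above
# now marks as supported on the V1 engine.
llm = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

outputs = llm.embed(["The capital of France is Paris."])
for out in outputs:
    # `out.outputs.embedding` is assumed to hold one embedding vector
    # (a list of floats) per input prompt.
    print(len(out.outputs.embedding))
```
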
@@ -442,9 +442,10 @@ Specified using `--task reward`.

 | Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
 |---------------------------|-----------------|------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------|
-| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | |
-| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | |
-| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | |
+| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |

 If your model is not in the above list, we will try to automatically convert the model using
 [as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly.
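
As context for the reward-model rows above (again, not part of this diff), a hedged sketch of running one of the newly V1-capable reward models, assuming the generic pooling entry point `LLM(task="reward")` / `LLM.encode` and that each output carries a pooled `data` tensor:

```python
from vllm import LLM

# One of the reward models the table above now marks as V1-capable.
llm = LLM(model="internlm/internlm2-1_8b-reward",
          task="reward",
          trust_remote_code=True)

outputs = llm.encode(["What is 2 + 2? The answer is 4."])
for out in outputs:
    # `out.outputs.data` is assumed to hold the pooled reward tensor for the
    # prompt (per-token hidden states by default, as noted above).
    print(out.outputs.data.shape)
```
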
@@ -460,7 +461,7 @@ Specified using `--task classify`.
 | Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
 |----------------------------------|----------|----------------------------------------|------------------------|-----------------------------|-----------------------|
 | `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
-| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | |
+| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
 If your model is not in the above list, we will try to automatically convert the model using
 [as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
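
A hedged offline sketch (not part of this diff) of the classification path described above, assuming `LLM(task="classify")` and `LLM.classify` expose the softmaxed class probabilities of the last token:

```python
from vllm import LLM

# GPT2ForSequenceClassification is now marked V1-capable in the table above;
# this checkpoint is a Polish sentiment classifier, hence the Polish prompt.
llm = LLM(model="nie3e/sentiment-polish-gpt2-small", task="classify")

outputs = llm.classify(["Bardzo dobry produkt, polecam!"])
for out in outputs:
    # `out.outputs.probs` is assumed to be the per-class probability vector
    # extracted from the softmaxed last-token hidden state.
    print(out.outputs.probs)
```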

@@ -471,7 +472,7 @@ Specified using `--task score`.
 | Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
 |---------------------------------------|-------------------|--------------------------------------------------------------------------------------|-----------------------|
 | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
-| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | |
+| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
 | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
 | `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |
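
Similarly, a hedged sketch (not part of this diff) of cross-encoder scoring, assuming `LLM(task="score")` and `LLM.score(query, documents)` return one relevance score per query-document pair:

```python
from vllm import LLM

# The Qwen3-Reranker seq-cls checkpoint is now marked V1-capable in the table above.
llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")

query = "What is the capital of France?"
documents = ["Paris is the capital of France.", "The Moon orbits the Earth."]

outputs = llm.score(query, documents)
for doc, out in zip(documents, outputs):
    # `out.outputs.score` is assumed to be a single float relevance score.
    print(doc, out.outputs.score)
```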

vllm/model_executor/models/qwen2_rm.py

Lines changed: 8 additions & 18 deletions
@@ -19,24 +19,12 @@
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput

-from .interfaces import SupportsLoRA, SupportsPP, SupportsV0Only
+from .interfaces import SupportsLoRA, SupportsPP
 from .qwen2 import Qwen2Model
 from .utils import AutoWeightsLoader, maybe_prefix


-class ReLU(nn.Module):
-
-    def __init__(self):
-        super().__init__()
-        self.activation = nn.ReLU()
-
-    def forward(self, input):
-        input, _ = input
-        return self.activation(input)
-
-
-class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
-                           SupportsV0Only):
+class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -65,11 +53,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.score = nn.Sequential(
             ColumnParallelLinear(config.hidden_size,
                                  config.hidden_size,
-                                 quant_config=quant_config),
-            ReLU(),
+                                 quant_config=quant_config,
+                                 return_bias=False),
+            nn.ReLU(),
             RowParallelLinear(config.hidden_size,
                               config.num_labels,
-                              quant_config=quant_config),
+                              quant_config=quant_config,
+                              return_bias=False),
         )
         self._pooler: SimplePooler
         self.make_empty_intermediate_tensors = (
@@ -87,7 +77,7 @@ def forward(
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
-        logits, _ = self.score(hidden_states)
+        logits = self.score(hidden_states)
         return logits

     def pooler(
