Skip to content

Commit bd4c1e6

Browse files
authored
Support for LlamaForSequenceClassification (#20807)
Signed-off-by: thechaos16 <thechaos16@gmail.com>
1 parent 99b4f08 commit bd4c1e6

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ def check_available_online(
330330
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
331331
"classifier_from_token": ["Yes"], # noqa: E501
332332
"method": "no_post_processing"}), # noqa: E501
333+
"LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
333334
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
334335
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
335336
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501

vllm/model_executor/models/llama.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from vllm.model_executor.sampling_metadata import SamplingMetadata
5050
from vllm.sequence import IntermediateTensors
5151

52+
from .adapters import as_seq_cls_model
5253
from .interfaces import SupportsLoRA, SupportsPP
5354
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
5455
is_pp_missing_parameter,
@@ -645,3 +646,6 @@ def permute(w: torch.Tensor, n_heads: int):
645646
name = name.replace(item, mapping[item])
646647

647648
return name, loaded_weight
649+
650+
651+
LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM)

vllm/model_executor/models/registry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@
183183
"GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501
184184
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501
185185
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
186-
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501
186+
"LlamaForSequenceClassification": ("llama", "LlamaForSequenceClassification"), # noqa: E501
187+
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
187188
}
188189

189190
_MULTIMODAL_MODELS = {

0 commit comments

Comments
 (0)