Commit f1e840e

[Model] GPT2ForSequenceClassification model (#19663)
Signed-off-by: nie3e <adrcwiek@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Parent: 7771d1d

File tree: 3 files changed, +57 −1 lines

tests/models/registry.py

Lines changed: 1 addition & 0 deletions

@@ -267,6 +267,7 @@ def check_available_online(
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True),  # noqa: E501
+    "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"),  # noqa: E501
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
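
The test registry maps the new architecture to a small public checkpoint that CI can pull. As a quick sanity check of that mapping (not part of the commit), one could confirm the checkpoint's Hugging Face config advertises the same architecture name; a minimal sketch, assuming `transformers` is installed and the model name is taken from the entry above:

# Hypothetical sanity check: the example checkpoint should declare the
# architecture registered above, and its num_labels drives the score head.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("nie3e/sentiment-polish-gpt2-small")
print(cfg.architectures)  # expected to include "GPT2ForSequenceClassification"
print(cfg.num_labels)     # number of classification labels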

vllm/model_executor/models/gpt2.py

Lines changed: 55 additions & 1 deletion

@@ -40,9 +40,11 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput

+from ..layers.pooler import Pooler, PoolingType
 from .interfaces import SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,

@@ -318,6 +320,58 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loader.load_weights(weights)


+class GPT2ForSequenceClassification(nn.Module):
+    """GPT2 Model for sequence classification.
+
+    This class expands GPT2Model with pooling and score functions - last token
+    is being used for classification.
+
+    Attributes:
+        transformer: An instance of GPT2Model used for forward operations.
+        score: A layer for calculating logits.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        self.transformer = GPT2Model(vllm_config=vllm_config,
+                                     prefix=maybe_prefix(prefix, "gpt2"))
+        self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
+        pooler_config = vllm_config.model_config.pooler_config
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=False,
+            softmax=True)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(
+            input_ids=input_ids,
+            position_ids=positions,
+            inputs_embeds=inputs_embeds,
+            intermediate_tensors=intermediate_tensors)
+        logits = self.score(hidden_states)
+        return logits
+
+
 def _add_transformer_prefix(
     weights: Iterable[tuple[str, torch.Tensor]]
 ) -> Iterable[tuple[str, torch.Tensor]]:
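
The new class routes GPT2's final hidden states through a LAST-token pooler with softmax, so the checkpoint's classification head can be served on vLLM's pooling path. A minimal offline-usage sketch, assuming the installed vLLM version exposes the `LLM.classify` API and reusing the example checkpoint from the test registry; output field names may differ between releases:

# Hypothetical usage sketch, not part of the commit.
from vllm import LLM

# task="classify" selects the pooling runner so GPT2ForSequenceClassification
# is loaded instead of the causal-LM variant of GPT2.
llm = LLM(model="nie3e/sentiment-polish-gpt2-small", task="classify")

outputs = llm.classify(["Ten film był naprawdę świetny!"])  # "This movie was really great!"
for out in outputs:
    # probs holds the softmaxed class scores produced by the LAST-token pooler.
    print(out.outputs.probs)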

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions

@@ -130,6 +130,7 @@
     "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
+    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
     "GritLM": ("gritlm", "GritLM"),
     "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
     "GteNewModel": ("bert_with_rope", "GteNewModel"),
