
Commit 1c3198b

[Model] Consolidate pooler implementations (#20927)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 260127e commit 1c3198b

File tree

9 files changed: +553 −367 lines

vllm/model_executor/layers/pooler.py

Lines changed: 434 additions & 247 deletions
Large diffs are not rendered by default.
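The consolidated pooler module itself is collapsed in this view, but the call sites below show its shape: ClassifierPooler now takes explicit pooling=, classifier=, and optional act_fn= keyword arguments, while SimplePooler and PoolingMethod supply the per-sequence reduction. A rough sketch of that composition, reconstructed from the usage in the other files (signatures are illustrative, not the actual vllm source):

import torch
import torch.nn as nn

class ClassifierPooler(nn.Module):
    """Illustrative sketch: pool -> classify -> activate in one module."""

    def __init__(self, model_config, *, pooling, classifier, act_fn=None):
        super().__init__()
        self.pooling = pooling        # e.g. CLSPool() or a BertPooler module
        self.classifier = classifier  # e.g. nn.Linear(hidden_size, num_labels)
        self.act_fn = act_fn          # e.g. a softmax head activation

    def forward(self, hidden_states, pooling_metadata):
        # Reduce token-level states to one vector per sequence, then score.
        pooled = self.pooling(hidden_states, pooling_metadata)
        logits = self.classifier(pooled)
        return logits if self.act_fn is None else self.act_fn(logits)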

vllm/model_executor/models/adapters.py

Lines changed: 48 additions & 51 deletions
@@ -58,22 +58,27 @@ def __init__(
         ) -> None:
             super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
+            self.vllm_config = vllm_config
+
             # These are not used in pooling models
             for attr in ("lm_head", "logits_processor"):
                 if hasattr(self, attr):
                     delattr(self, attr)
 
+            # If the model already defines a pooler instance, don't overwrite it
+            if not getattr(self, "_pooler", None):
+                self._init_pooler(vllm_config, prefix=prefix)
+
+        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
             pooler_config = vllm_config.model_config.pooler_config
             assert pooler_config is not None
 
-            # If the model already defines a pooler instance, don't overwrite it
-            if not getattr(self, "_pooler", None):
-                self._pooler = Pooler.from_config_with_defaults(
-                    pooler_config,
-                    pooling_type=default_pooling_type,
-                    normalize=default_normalize,
-                    softmax=default_softmax,
-                )
+            self._pooler = Pooler.from_config_with_defaults(
+                pooler_config,
+                pooling_type=default_pooling_type,
+                normalize=default_normalize,
+                softmax=default_softmax,
+            )
 
         def pooler(
             self,
@@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T:
 
     # Lazy import
    from vllm.model_executor.layers.linear import RowParallelLinear
-    from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
+    from vllm.model_executor.layers.pooler import (ClassifierPooler,
+                                                   PoolerOutput, PoolingType,
+                                                   SimplePooler)
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding
    from vllm.model_executor.pooling_metadata import PoolingMetadata
    from vllm.sequence import IntermediateTensors
@@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T:
     class ModelForSequenceClassification(ModelForPooling,
                                          SupportsCrossEncoding):
 
-        def __init__(
-            self,
-            *,
-            vllm_config: "VllmConfig",
-            prefix: str = "",
-            **kwargs: Any,
-        ) -> None:
-            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
-
+        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
             config = vllm_config.model_config.hf_config
             quant_config = vllm_config.quant_config
 
-            self.vllm_config = vllm_config
-            self.task = vllm_config.model_config.task
-            self.pooling_type = (
-                vllm_config.model_config.pooler_config.pooling_type)
-
-            self.score = RowParallelLinear(config.hidden_size,
-                                           config.num_labels,
-                                           quant_config=quant_config,
-                                           input_is_parallel=False,
-                                           bias=False,
-                                           prefix=maybe_prefix(
-                                               prefix, "score"))
+            self.score = RowParallelLinear(
+                config.hidden_size,
+                config.num_labels,
+                input_is_parallel=False,
+                bias=False,
+                params_dtype=torch.float32,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "score"),
+            )
+
+            pooler_config = vllm_config.model_config.pooler_config
+            assert pooler_config is not None
+
+            pooler = SimplePooler.from_config_with_defaults(
+                pooler_config,
+                pooling_type=PoolingType.LAST,
+                normalize=False,
+                softmax=True,
+            )
+
+            self._pooler = ClassifierPooler(
+                vllm_config.model_config,
+                pooling=pooler.pooling,
+                classifier=self._classifier,
+                act_fn=pooler.head.activation,
+            )
+
+        def _classifier(self, x: torch.Tensor):
+            x, _ = self.score(x.float())
+            return x
 
         def forward(
             self,
@@ -222,27 +239,7 @@ def pooler(
             hidden_states: Union[torch.Tensor, list[torch.Tensor]],
             pooling_metadata: PoolingMetadata,
         ) -> PoolerOutput:
-
-            def get_logits(hidden_states):
-                if isinstance(hidden_states, list):
-                    logits = [self.score(state)[0] for state in hidden_states]
-                else:
-                    logits, _ = self.score(hidden_states)
-                return logits
-
-            if self.pooling_type == PoolingType.ALL:
-                logits = get_logits(hidden_states)
-                return self._pooler(logits, pooling_metadata)
-            else:
-                hidden_states = self._pooler.extract_states(
-                    hidden_states, pooling_metadata)
-                logits = get_logits(hidden_states)
-                pooled_data = self._pooler.head(logits, pooling_metadata)
-
-                pooled_outputs = [
-                    self._pooler.build_output(data) for data in pooled_data
-                ]
-                return PoolerOutput(outputs=pooled_outputs)
+            return self._pooler(hidden_states, pooling_metadata)
 
         def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             tokens = getattr(self.config, "classifier_from_token", None)
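The net effect of the refactor: _init_pooler is now the single extension point, so a subclass swaps in its own pooler once and inherits the shared pooler() dispatch instead of duplicating it. A minimal sketch of such a subclass (the class name and the score layer are illustrative placeholders):

class MyModelForSequenceClassification(ModelForPooling):

    def _init_pooler(self, vllm_config, prefix: str = ""):
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        # Last-token pooling with a softmax head, as the adapter above does.
        pooler = SimplePooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=False,
            softmax=True,
        )
        self._pooler = ClassifierPooler(
            vllm_config.model_config,
            pooling=pooler.pooling,
            classifier=self.score,  # assumes the model defines a score layer
            act_fn=pooler.head.activation,
        )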

vllm/model_executor/models/bert.py

Lines changed: 16 additions & 9 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch import nn
@@ -18,7 +18,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
-                                               PoolingType)
+                                               PoolingMethod, PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -84,14 +84,18 @@ class BertPooler(nn.Module):
 
     def __init__(self, config: BertConfig):
         super().__init__()
+
+        self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # We "pool" the model by simply taking the hidden state corresponding
-        # to the first token.
-        first_token_tensor = hidden_states[0, :]
-        pooled_output = self.dense(first_token_tensor)
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        pooled_output = self.pooling(hidden_states, pooling_metadata)
+        pooled_output = self.dense(pooled_output)
         pooled_output = self.activation(pooled_output)
         return pooled_output
 
@@ -472,8 +476,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                      embedding_class=BertEmbedding,
                                      add_pooling_layer=True)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self._pooler = ClassifierPooler(vllm_config.model_config,
-                                        self.classifier, self.bert.pooler)
+        self._pooler = ClassifierPooler(
+            vllm_config.model_config,
+            pooling=self.bert.pooler,
+            classifier=self.classifier,
+        )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
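The old BertPooler sliced out the first token by hand; the new one delegates to a PoolingMethod that performs the same selection per sequence across a flattened batch, driven by PoolingMetadata. For a single sequence the two are equivalent, as this small check illustrates (shapes are arbitrary):

import torch

seq_len, hidden_size = 16, 64
hidden_states = torch.randn(seq_len, hidden_size)

# Old path: take the hidden state of the first token directly.
cls_old = hidden_states[0, :]

# New path (conceptually): PoolingMethod.from_pooling_type(PoolingType.CLS)
# selects the same first-token state, but for every sequence in the batch.
cls_new = hidden_states[0]

assert torch.equal(cls_old, cls_new)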

vllm/model_executor/models/gritlm.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 
 from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import PoolerHead
+from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
@@ -49,7 +49,7 @@ def tokens_to_ids(tokens: list[str]) -> array:
         self.embed_pattern_ids = tokens_to_ids(
             ["▁<", "|", "embed", "|", ">", "<0x0A>"])
 
-        self.head = PoolerHead(normalize=True, softmax=False)
+        self.head = PoolerHead(PoolerNormalize())
 
     def _find_array(self, arr: array, target: array, start_idx: int) -> int:
         """

vllm/model_executor/models/interfaces.py

Lines changed: 1 addition & 1 deletion
@@ -659,7 +659,7 @@ def supports_cross_encoding(
 def has_step_pooler(model: Union[type[object], object]) -> bool:
     """Check if the model uses step pooler."""
     return is_pooling_model(model) and any(
-        type(module).__name__ == "StepPool" for module in model.modules())
+        type(module).__name__ == "StepPooler" for module in model.modules())
 
 
 class SupportsQuant:

vllm/model_executor/models/jamba.py

Lines changed: 26 additions & 13 deletions
@@ -19,7 +19,8 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
+                                               SimplePooler)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
+
         config = vllm_config.model_config.hf_config
         num_labels: int = config.num_labels
         score_bias: bool = getattr(config, 'score_bias', False)
-        self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)
+
+        # TODO: The original reward weights have float32 accuracy data, we
+        # would like to load them in fp32 to get that extra precision.
+        # Currently weight_loader passes the weight which is already in bf16
+        self.score = nn.Linear(
+            config.hidden_size,
+            num_labels,
+            bias=score_bias,
+            dtype=torch.float32,
+        )
 
         pooler_config = vllm_config.model_config.pooler_config
-        self._pooler = Pooler.from_config_with_defaults(
+        assert pooler_config is not None
+
+        pooler = SimplePooler.from_config_with_defaults(
             pooler_config,
             pooling_type=PoolingType.LAST,
             normalize=False,
-            softmax=False)
+            softmax=False,
+        )
+
+        self._pooler = ClassifierPooler(
+            vllm_config.model_config,
+            pooling=pooler.pooling,
+            classifier=self.score,
+            act_fn=pooler.head.activation,
+        )
 
     def pooler(
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
-        hidden_states = hidden_states.float()
-        logits = self.score(hidden_states)
-        return self._pooler(logits, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        # TODO: The reward weights themselves have float32 accuracy data, we
-        # would like to load them in fp32 to get that extra precision.
-        super().load_weights(weights)
-        self.score = self.score.float()
+        return self._pooler(hidden_states, pooling_metadata)
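The TODO comment documents the design constraint: the classification head is created in float32 while the backbone typically runs in bf16, so activations must be upcast at the boundary (the adapters.py _classifier helper above does exactly this with x.float()). A minimal sketch of the pattern:

import torch
import torch.nn as nn

hidden_size, num_labels = 64, 2

# Head parameters live in fp32 even though activations arrive in bf16.
score = nn.Linear(hidden_size, num_labels, bias=False, dtype=torch.float32)

hidden = torch.randn(3, hidden_size, dtype=torch.bfloat16)
logits = score(hidden.float())  # upcast at the boundary
print(logits.dtype)  # torch.float32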

vllm/model_executor/models/modernbert.py

Lines changed: 18 additions & 15 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch import nn
@@ -13,7 +13,8 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import ClassifierPooler
+from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler,
+                                               PoolingMethod, PoolingType)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -252,10 +253,13 @@ def forward(
         return norm_outputs
 
 
-class ModernBertPooler(nn.Module):
+class ModernBertPooler(BasePooler):
 
     def __init__(self, config: ModernBertConfig):
         super().__init__()
+
+        pooling_type = PoolingType[config.classifier_pooling.upper()]
+        self.pooling = PoolingMethod.from_pooling_type(pooling_type)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size,
                                config.classifier_bias)
         self.pooling_type = config.classifier_pooling
@@ -264,15 +268,12 @@ def __init__(self, config: ModernBertConfig):
                                  eps=config.norm_eps,
                                  bias=config.norm_bias)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        pooled_output = hidden_states
-        if self.pooling_type == "mean":
-            pooled_output = pooled_output.mean(dim=0, keepdim=False)
-        elif self.pooling_type == "cls":
-            pooled_output = pooled_output[0, :]
-        else:
-            raise ValueError("Pooling type should be either `cls` or `mean`, "
-                             f"but got {self.pooling_type}")
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        pooled_output = self.pooling(hidden_states, pooling_metadata)
         pooled_output = self.norm(self.act(self.dense(pooled_output)))
         return pooled_output
 
@@ -287,9 +288,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self._pooler = ClassifierPooler(vllm_config.model_config,
-                                        self.classifier,
-                                        ModernBertPooler(config))
+        self._pooler = ClassifierPooler(
+            vllm_config.model_config,
+            pooling=ModernBertPooler(config),
+            classifier=self.classifier,
+        )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
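The hand-written if/elif over config.classifier_pooling becomes a name lookup into the PoolingType enum. A small sketch of the selection logic (using an illustrative subset of the enum, not vllm's actual definition):

from enum import Enum

class PoolingType(Enum):  # illustrative subset of vllm's enum
    CLS = 0
    MEAN = 1

classifier_pooling = "mean"  # e.g. from ModernBertConfig
pooling_type = PoolingType[classifier_pooling.upper()]
print(pooling_type)  # PoolingType.MEAN; unknown names raise KeyError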

vllm/model_executor/models/roberta.py

Lines changed: 8 additions & 5 deletions
@@ -9,7 +9,7 @@
 from transformers import RobertaConfig
 
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import ClassifierPooler
+from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
@@ -106,8 +106,8 @@ def __init__(self, config: RobertaConfig):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
 
-    def forward(self, features, **kwargs):
-        x = features[0, :]  # take <s> token (equiv. to [CLS])
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # CLSPool has already been applied in `pooling`
         x = self.dense(x)
         x = torch.tanh(x)
         x = self.out_proj(x)
@@ -188,8 +188,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                              add_pooling_layer=False)
         self.classifier = RobertaClassificationHead(config)
 
-        self._pooler = ClassifierPooler(vllm_config.model_config,
-                                        self.classifier)
+        self._pooler = ClassifierPooler(
+            vllm_config.model_config,
+            pooling=CLSPool(),
+            classifier=self.classifier,
+        )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
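Because the RoBERTa model is built with add_pooling_layer=False, there is no pooling module to reuse, so a standalone CLSPool() supplies the reduction and RobertaClassificationHead now receives an already-pooled [num_seqs, hidden] tensor. A shape-level sketch of the new data flow (sizes are illustrative):

import torch
import torch.nn as nn

hidden_size, num_labels, num_seqs = 768, 2, 3

pooled = torch.randn(num_seqs, hidden_size)  # output of CLSPool()
dense = nn.Linear(hidden_size, hidden_size)
out_proj = nn.Linear(hidden_size, num_labels)

# Mirrors RobertaClassificationHead.forward after the change.
logits = out_proj(torch.tanh(dense(pooled)))
print(logits.shape)  # torch.Size([3, 2])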
