Skip to content

Commit 23a04e0

Browse files
authored
[Fix] Support cls pooling in ModernBertPooler (#20067)
Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
1 parent 02c97d9 commit 23a04e0

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

vllm/model_executor/models/modernbert.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,14 +258,21 @@ def __init__(self, config: ModernBertConfig):
258258
super().__init__()
259259
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
260260
config.classifier_bias)
261+
self.pooling_type = config.classifier_pooling
261262
self.act = nn.GELU()
262263
self.norm = nn.LayerNorm(config.hidden_size,
263264
eps=config.norm_eps,
264265
bias=config.norm_bias)
265266

266267
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
267268
pooled_output = hidden_states
268-
pooled_output = pooled_output.mean(dim=0, keepdim=False)
269+
if self.pooling_type == "mean":
270+
pooled_output = pooled_output.mean(dim=0, keepdim=False)
271+
elif self.pooling_type == "cls":
272+
pooled_output = pooled_output[0, :]
273+
else:
274+
raise ValueError("Pooling type should be either `cls` or `mean`, "
275+
f"but got {self.pooling_type}")
269276
pooled_output = self.norm(self.act(self.dense(pooled_output)))
270277
return pooled_output
271278

0 commit comments

Comments
 (0)